Skip to content

Commit

Permalink
Add TFDPR (huggingface#8203)
Browse files Browse the repository at this point in the history
* Create modeling_tf_dpr.py

* Add TFDPR

* Add back TFPegasus, TFMarian, TFMBart, TFBlenderBot

last commit accidentally deleted these 4 lines, so I recover them back

* Add TFDPR

* Add TFDPR

* clean up some comments, add TF input-style doc string

* Add TFDPR

* Make return_dict=False as default

* Fix return_dict bug (in .from_pretrained)

* Add get_input_embeddings()

* Create test_modeling_tf_dpr.py

The current version is already passed all 27 tests!
Please see the test run at : 
https://colab.research.google.com/drive/1czS_m9zy5k-iSJbzA_DP1k1xAAC_sdkf?usp=sharing

* fix quality

* delete init weights

* run fix copies

* fix repo consis

* del config_class, load_tf_weights

They shoud be 'pytorch only'

* add config_class back

after removing it, test failed ... so totally only removing "use_tf_weights = None" on Lysandre suggestion

* newline after .. note::

* import tf, np (Necessary for ModelIntegrationTest)

* slow_test from_pretrained with from_pt=True

At the moment we don't have TF weights (since we don't have official official TF model)
Previously, I did not run slow test, so I missed this bug

* Add simple TFDPRModelIntegrationTest

Note that this is just a test that TF and Pytorch gives approx. the same output.
However, I could not test with the official DPR repo's output yet

* upload correct tf model

* remove position_ids as missing keys

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: patrickvonplaten <patrick@huggingface.co>
  • Loading branch information
3 people authored and fabiocapsouza committed Nov 15, 2020
1 parent aecfa82 commit 1d245ec
Show file tree
Hide file tree
Showing 10 changed files with 1,133 additions and 0 deletions.
19 changes: 19 additions & 0 deletions docs/source/model_doc/dpr.rst
Original file line number Diff line number Diff line change
Expand Up @@ -99,3 +99,22 @@ DPRReader

.. autoclass:: transformers.DPRReader
:members: forward

TFDPRContextEncoder
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.TFDPRContextEncoder
:members: call

TFDPRQuestionEncoder
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.TFDPRQuestionEncoder
:members: call


TFDPRReader
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.TFDPRReader
:members: call
14 changes: 14 additions & 0 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,9 @@
DistilBertPreTrainedModel,
)
from .modeling_dpr import (
DPR_CONTEXT_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST,
DPR_QUESTION_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST,
DPR_READER_PRETRAINED_MODEL_ARCHIVE_LIST,
DPRContextEncoder,
DPRPretrainedContextEncoder,
DPRPretrainedQuestionEncoder,
Expand Down Expand Up @@ -713,6 +716,17 @@
TFDistilBertModel,
TFDistilBertPreTrainedModel,
)
from .modeling_tf_dpr import (
TF_DPR_CONTEXT_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST,
TF_DPR_QUESTION_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST,
TF_DPR_READER_PRETRAINED_MODEL_ARCHIVE_LIST,
TFDPRContextEncoder,
TFDPRPretrainedContextEncoder,
TFDPRPretrainedQuestionEncoder,
TFDPRPretrainedReader,
TFDPRQuestionEncoder,
TFDPRReader,
)
from .modeling_tf_electra import (
TF_ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST,
TFElectraForMaskedLM,
Expand Down
22 changes: 22 additions & 0 deletions src/transformers/convert_pytorch_checkpoint_to_tf2.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@
CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP,
DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
DPR_CONTEXT_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST,
DPR_QUESTION_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST,
DPR_READER_PRETRAINED_MODEL_ARCHIVE_LIST,
ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP,
FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP,
Expand All @@ -43,6 +46,7 @@
CamembertConfig,
CTRLConfig,
DistilBertConfig,
DPRConfig,
ElectraConfig,
FlaubertConfig,
GPT2Config,
Expand All @@ -59,6 +63,9 @@
TFCTRLLMHeadModel,
TFDistilBertForMaskedLM,
TFDistilBertForQuestionAnswering,
TFDPRContextEncoder,
TFDPRQuestionEncoder,
TFDPRReader,
TFElectraForPreTraining,
TFFlaubertWithLMHeadModel,
TFGPT2LMHeadModel,
Expand Down Expand Up @@ -98,6 +105,9 @@
CTRLLMHeadModel,
DistilBertForMaskedLM,
DistilBertForQuestionAnswering,
DPRContextEncoder,
DPRQuestionEncoder,
DPRReader,
ElectraForPreTraining,
FlaubertWithLMHeadModel,
GPT2LMHeadModel,
Expand Down Expand Up @@ -147,6 +157,18 @@
BertForSequenceClassification,
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
),
"dpr": (
DPRConfig,
TFDPRQuestionEncoder,
TFDPRContextEncoder,
TFDPRReader,
DPRQuestionEncoder,
DPRContextEncoder,
DPRReader,
DPR_CONTEXT_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST,
DPR_QUESTION_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST,
DPR_READER_PRETRAINED_MODEL_ARCHIVE_LIST,
),
"gpt2": (
GPT2Config,
TFGPT2LMHeadModel,
Expand Down
3 changes: 3 additions & 0 deletions src/transformers/modeling_tf_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
replace_list_option_in_docstrings,
)
from .configuration_blenderbot import BlenderbotConfig
from .configuration_dpr import DPRConfig
from .configuration_marian import MarianConfig
from .configuration_mbart import MBartConfig
from .configuration_pegasus import PegasusConfig
Expand Down Expand Up @@ -87,6 +88,7 @@
TFDistilBertForTokenClassification,
TFDistilBertModel,
)
from .modeling_tf_dpr import TFDPRQuestionEncoder
from .modeling_tf_electra import (
TFElectraForMaskedLM,
TFElectraForMultipleChoice,
Expand Down Expand Up @@ -192,6 +194,7 @@
(CTRLConfig, TFCTRLModel),
(ElectraConfig, TFElectraModel),
(FunnelConfig, TFFunnelModel),
(DPRConfig, TFDPRQuestionEncoder),
]
)

Expand Down
Loading

0 comments on commit 1d245ec

Please sign in to comment.