This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

[RAG] Handle TFIDF retriever with pre-trained model #4436

Merged
merged 3 commits into from
Mar 23, 2022

klshuster
Contributor

Patch description
There is currently an issue (#4422) when using a pre-trained RAG/FiD model from the zoo with a TFIDF retriever. This fix includes the query_encoder in the TFIDF retriever's state dict.
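
As a rough illustration of why state dict contents matter here, the sketch below uses a plain-Python stand-in for strict state-dict loading. It is not ParlAI or PyTorch code. It assumes (the PR does not spell this out) that the zoo checkpoint carries query encoder weights that a TFIDF retriever without a query encoder has no slot for. All key names and helpers are hypothetical.

```python
from typing import Dict, List


def strict_load(model_keys: List[str], checkpoint: Dict[str, float]) -> None:
    """Raise if checkpoint keys and model keys differ (mimics strict loading)."""
    missing = sorted(set(model_keys) - set(checkpoint))
    unexpected = sorted(set(checkpoint) - set(model_keys))
    if missing or unexpected:
        raise KeyError(f"missing={missing}, unexpected={unexpected}")


# Hypothetical checkpoint saved from a model that had a query encoder.
checkpoint = {
    "seq2seq.weight": 0.1,
    "retriever.query_encoder.weight": 0.2,
}

# Hypothetical TFIDF retriever with no query encoder: keys do not line up.
keys_without_encoder = ["seq2seq.weight"]
# The same retriever once the query encoder is part of its state dict.
keys_with_encoder = ["seq2seq.weight", "retriever.query_encoder.weight"]


def loads_cleanly(model_keys: List[str]) -> bool:
    """Return True when the checkpoint loads without a key mismatch."""
    try:
        strict_load(model_keys, checkpoint)
        return True
    except KeyError:
        return False


print(loads_cleanly(keys_without_encoder))  # False
print(loads_cleanly(keys_with_encoder))     # True
```

Under that assumption, making the query encoder part of the retriever's state dict is what lets the pre-trained weights load without a key mismatch.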

Testing steps
CI

$ pytest tests/nightly/gpu/test_rag.py  -x
================================================================================ test session starts ================================================================================
platform linux -- Python 3.7.9, pytest-5.3.2, py-1.10.0, pluggy-0.13.1
rootdir: /private/home/kshuster/ParlAI, inifile: pytest.ini
plugins: hydra-core-1.1.0, requests-mock-1.7.0, regressions-2.1.1, datadir-1.3.1
collected 43 items

tests/nightly/gpu/test_rag.py ...........................................                                                                                                     [100%]

============================================================================= slowest 10 test durations =============================================================================
69.95s call     tests/nightly/gpu/test_rag.py::TestRagZooModels::test_bart_rag_dpr_poly
67.92s call     tests/nightly/gpu/test_rag.py::TestRagTfidf::test_rag_token
53.12s call     tests/nightly/gpu/test_rag.py::TestFidZooModels::test_bart_fid_rag_dpr_poly
52.52s call     tests/nightly/gpu/test_rag.py::TestRagDprPoly::test_rag_sequence
50.83s call     tests/nightly/gpu/test_rag.py::TestLoadDPRModel::test_load_dpr
47.06s call     tests/nightly/gpu/test_rag.py::TestRagDprPoly::test_rag_turn
45.98s call     tests/nightly/gpu/test_rag.py::TestRagZooModels::test_bart_rag_sequence
45.39s call     tests/nightly/gpu/test_rag.py::TestRagZooModels::test_bart_rag_turn_do
45.07s call     tests/nightly/gpu/test_rag.py::TestRagZooModels::test_bart_rag_turn_dtt
44.52s call     tests/nightly/gpu/test_rag.py::TestFidZooModels::test_bart_fid_dpr
=================================================================== 43 passed, 24 warnings in 1095.44s (0:18:15) ====================================================================

@@ -553,6 +553,18 @@ def _set_text_vec(
"""
return self._generation_agent._set_text_vec(self, obs, history, truncate)

def _add_generation_metrics(self, batch: Batch, preds: List[torch.Tensor]):
Contributor

Won't the RagAgent inherit this from the TGA (TorchGeneratorAgent)?

Contributor Author

I needed to make a change to not use batch.batchsize, since that fails for RAG Sequence.

I suppose the change could actually just be made at the TGA level.
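
A minimal sketch of the idea in the reply above, under a stated assumption: the thread does not say why batch.batchsize fails for RAG Sequence, so this example simply assumes the number of predictions can differ from the batch's recorded size. `FakeBatch` and both helper functions are hypothetical, and plain lists of token ids stand in for tensors.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class FakeBatch:
    """Hypothetical stand-in for a batch; only the field used here."""
    batchsize: int


def token_counts_via_batchsize(batch: FakeBatch, preds: List[List[int]]) -> List[int]:
    # Indexes preds by batch.batchsize; raises IndexError if preds is shorter.
    return [len(preds[i]) for i in range(batch.batchsize)]


def token_counts_via_preds(batch: FakeBatch, preds: List[List[int]]) -> List[int]:
    # Derives the count from preds, so it never depends on batch.batchsize.
    return [len(p) for p in preds]


batch = FakeBatch(batchsize=4)   # assumed mismatch: 4 recorded, 2 predictions
preds = [[5, 6, 7], [8, 9]]      # token ids for two generations

print(token_counts_via_preds(batch, preds))  # [3, 2]
```

Deriving the count from `preds` keeps the metric code independent of how a particular agent defines its batch size, which is also why the same change could live at the TGA level.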

Contributor

jxmsML left a comment

LGTM! 🚀

klshuster merged commit 4439f94 into main on Mar 23, 2022
klshuster deleted the fix_rag_tfidf branch on March 23, 2022 13:57