
[RAD] OSS RAG & FiD #3611

Merged

merged 18 commits into master from rag_oss on Apr 28, 2021

Conversation

klshuster
Contributor

@klshuster commented Apr 21, 2021

Patch description
RAG and FiD in ParlAI. There's too much to cover in the PR description, so I would direct readers to the included READMEs for more detailed instructions. Here's a quick list of what these changes encompass:

  1. RAG and FiD are now agents in ParlAI -> --model rag, --model fid
  2. Included are scripts for generating FAISS indices to use with RAG and FiD
  3. Included are build files for eight pre-trained models from this project
  4. Included are 31 CI tests that ensure all appropriate combinations of RAG/FiD can be run, and that the pre-trained models load appropriately
  5. Included is a DropoutPolyencoder, the base model of the PolyFAISS method described in [this paper](https://arxiv.org/abs/2104.07567)
  6. Included are comprehensive README additions (for RAG) and updates (for RAD)
  7. Included is an opt_preset file for BART-Large.

The following RAG options are implemented (an example invocation follows the list):

  1. Generation Models: BART, T5, Transformer/Generator (the last is useful for, e.g., BlenderBot)
  2. Model Types: RAG Sequence, RAG Turn, RAG Token, RAG-ReGReT
  3. Retrievers: DPR, TFIDF, DPR-Poly, PolyFAISS
  4. Indexes: Exact, Compressed
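
To make the combinations concrete, here is a rough example of a training invocation mixing the options above. This is an illustrative sketch only; flag names such as --generation-model, --rag-model-type, --rag-retriever-type, and --indexer-type should be double-checked against the included READMEs.

$ parlai train_model --model rag --task wizard_of_wikipedia \
    --generation-model bart --rag-model-type token \
    --rag-retriever-type dpr --indexer-type exact \
    --batchsize 16 --fp16 True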

Testing steps
CI


$ pytest -x test_rag.py
================== test session starts ==================
platform linux -- Python 3.7.9, pytest-6.2.1, py-1.10.0, pluggy-1.0.0.dev0
rootdir: /private/home/kshuster/ParlAI, configfile: pytest.ini
plugins: hydra-core-1.0.0, requests-mock-1.8.0, regressions-2.1.1, datadir-1.3.1
collected 31 items

test_rag.py ...............................                                                                                                                                   [100%]

================== slowest 10 durations ==================
257.50s call     tests/nightly/gpu/test_rag.py::TestRagDpr::test_t5_rag_turn
183.55s call     tests/nightly/gpu/test_rag.py::TestRagDprPoly::test_bart_rag_turn
131.02s call     tests/nightly/gpu/test_rag.py::TestRagDpr::test_reddit_rag_turn
101.48s call     tests/nightly/gpu/test_rag.py::TestRagDpr::test_bart_rag_turn
90.21s call     tests/nightly/gpu/test_rag.py::TestRagDpr::test_t5_rag_sequence
67.14s call     tests/nightly/gpu/test_rag.py::TestZooModels::test_bart_rag_dpr_poly
62.91s call     tests/nightly/gpu/test_rag.py::TestRagDprPoly::test_bart_rag_sequence
55.08s call     tests/nightly/gpu/test_rag.py::TestZooModels::test_bart_fid_rag_dpr_poly
53.19s call     tests/nightly/gpu/test_rag.py::TestZooModels::test_bart_rag_sequence
52.52s call     tests/nightly/gpu/test_rag.py::TestZooModels::test_bart_rag_turn_do
================== 31 passed, 16 warnings in 1711.35s (0:28:31) ==================
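
A single configuration can also be re-run on its own via the standard pytest node id, e.g. (test name taken from the timing list above):

$ pytest -x tests/nightly/gpu/test_rag.py::TestRagDpr::test_bart_rag_turn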

parlai/core/torch_ranker_agent.py (outdated; resolved)
parlai/agents/transformer/dropout_poly.py (outdated; resolved)
if self.use_codes:
    ctxt_rep, ctxt_rep_mask, _ = self.model(**self._model_context_input(batch))
else:
    model = self.model.module if hasattr(self.model, 'module') else self.model
Contributor
nit: add a comment mentioning DistributedDataParallel
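
For example (illustrative only, not part of the diff), the comment could note that .module is how PyTorch's DistributedDataParallel exposes the wrapped model:

else:
    # In distributed training self.model is wrapped in DistributedDataParallel,
    # which exposes the underlying network via .module; unwrap it so we can
    # call its methods directly.
    model = self.model.module if hasattr(self.model, 'module') else self.model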

parlai/agents/transformer/dropout_poly.py (outdated; resolved)
parlai/agents/transformer/dropout_poly.py (outdated; resolved)
parlai/agents/rag/dpr_biencoder.py (outdated; resolved)
parlai/agents/rag/retrievers.py (outdated; resolved)
parlai/agents/rag/retrievers.py (outdated; resolved)
parlai/agents/rag/retrievers.py (outdated; resolved)
parlai/agents/rag/retrievers.py (outdated; resolved)
torch.Tensor,
]:
"""
Reorder the encoder states, for bean search.
Contributor
lol bean search

Contributor Author
i'm tempted to leave this in haha

return dec_inputs # type: ignore


class RagModelInterface(ABC):
Contributor
Love this!

parlai/agents/fid/fid.py (outdated; resolved)
Comment on lines 71 to 86
    def reorder_encoder_states(
        self,
        encoder_states: Tuple[torch.Tensor, ...],
        indices: Union[List[int], torch.LongTensor],
    ) -> Tuple[torch.Tensor, torch.Tensor, List[List[Document]], torch.Tensor]:
        """
        Reorder the encoder states.

        Override TGM.reorder_encoder_states to make sure we only pass enc, mask.

        See ``TorchGeneratorModel.reorder_encoder_states`` for a description.
        """
        enc, mask, *_ = encoder_states
        return TransformerGeneratorModel.reorder_encoder_states(
            self, (enc, mask), indices
        )
Contributor
This may be out of the scope of this PR. But I feel like we could eliminate having to override this if encoder_states were a dict or **kwargs instead of a tuple.

In the scope of this PR, you could remove the need for this function by modifying TransformerGeneratorModel.reorder_encoder_states to do enc, mask = encoder_states[:2].
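
A minimal sketch of that suggested change (hypothetical; the real TransformerGeneratorModel implementation should be checked before adopting it):

import torch  # already imported in the real module

# Inside TransformerGeneratorModel:
def reorder_encoder_states(self, encoder_states, indices):
    # Take only the first two items so that subclasses may append extra
    # state (documents, scores, ...) without having to override this method.
    enc, mask = encoder_states[:2]
    if not torch.is_tensor(indices):
        indices = torch.LongTensor(indices).to(enc.device)
    enc = torch.index_select(enc, 0, indices)
    mask = torch.index_select(mask, 0, indices)
    return enc, mask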

Contributor
I actually find the forcing nice, as it causes people to realize they need to handle reordering states. kwargs could hide something that isn't being properly shuffled

Contributor Author
i should probably clarify this docstring, as we're overriding RagModel.reorder_encoder_states, not TGM

parlai/agents/rag/rag.py (outdated; resolved)
parlai/agents/rag/rag.py (outdated; resolved)
parlai/agents/rag/rag.py (outdated; resolved)
parlai/agents/rag/rag.py (outdated; resolved)
parlai/agents/rag/rag.py (outdated; resolved)
"[Fid]: I love Elvis Presley! He is my favorite singer, songwriter, actor, and producer."
),
"example2": (
"parlai eval_model -mf zoo:hallucination/bart_fid_rag_dpr_poly/model -t wizard_of_wikipedia --num-examples 100"
Contributor Author
gotta remove this and the following --num-examples 100, as these numbers are from the full valid set

"example": (
"parlai eval_model -mf zoo:hallucination/bart_rag_token/model --indexer-type exact --path-to-index zoo:hallucination/wow_passages/exact --path-to-dpr-passages zoo:hallucination/wow_passages/wow_articles.paragraphs.tsv -ne 100"
),
"result": ("TODO"),
Contributor Author
also gotta fill in these results

--batchsize 16 --fp16 True --gradient-clip 0.1 --label-truncate 128 \
--log-every-n-secs 30 --lr-scheduler reduceonplateau --lr-scheduler-patience 1 \
--model-parallel True --optimizer adam --text-truncate 512 --truncate 512 \
-lr 1e-05 -vmm min -veps 0.25 -vme 1000 -vmt ppl -vp 5 \
Contributor Author
maybe add a note here to open an issue if other options are desired?

Contributor
Opt presets are available now...

Contributor Author
i know i added one for BART. i wasn't sure what to call these though... opt/rag? since these are optimization/training parameters

Contributor
you can put them in the project folder and specify it with -o projects/hallucination/very_long_name.opt
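
For reference, an opt preset is just a JSON dict of option values that -o loads as defaults; a hypothetical file along those lines, mirroring the training flags above (key names illustrative and to be checked against the actual ParlAI option names), might look like:

{
    "batchsize": 16,
    "fp16": true,
    "gradient_clip": 0.1,
    "label_truncate": 128,
    "optimizer": "adam",
    "learningrate": 1e-05
}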

Contributor Author
wow, wasn't aware of that, i'll try it out

Contributor

@spencerp left a comment
This is great! Unless @stephenroller or @moyapchen have any concerns, it feels ready to merge.

Comment on lines +539 to +540
encoder=opt['t5'].get_encoder(),
encoder_class=ParlaiT5Encoder,
Contributor
This is pretty funky. It should be made better by the "swappable subcomponents" change.

@stephenroller
Contributor

Looks like tests still don't pass (maybe merge master into this?) but I defer.

@klshuster
Contributor Author

i'll try getting tests to pass before i merge this

@klshuster klshuster merged commit aa71230 into master Apr 28, 2021
@klshuster klshuster deleted the rag_oss branch April 28, 2021 01:51
4 participants