[RAD] OSS RAG & FiD #3611
Conversation
```python
if self.use_codes:
    ctxt_rep, ctxt_rep_mask, _ = self.model(**self._model_context_input(batch))
else:
    model = self.model.module if hasattr(self.model, 'module') else self.model
```
nit: add a comment mentioning DistributedDataParallel
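A sketch of what that comment might look like on the unwrap line from the diff above (the wording is a suggestion, not the committed code):

```python
# self.model may be wrapped in DistributedDataParallel (or DataParallel),
# which exposes the underlying module via .module; unwrap it so we can
# call methods that the wrapper does not forward.
model = self.model.module if hasattr(self.model, 'module') else self.model
```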
`parlai/agents/rag/model_types.py` (Outdated)
```python
    torch.Tensor,
]:
    """
    Reorder the encoder states, for bean search.
```
lol bean search
i'm tempted to leave this in haha
```python
    return dec_inputs  # type: ignore


class RagModelInterface(ABC):
```
Love this!
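For context, a minimal sketch (class and method names hypothetical, not ParlAI's actual interface) of the pattern being praised: an abstract base class lets each RAG model type supply its own hooks while the agent stays generic:

```python
from abc import ABC, abstractmethod
import torch

class ModelTypeInterface(ABC):
    """Hypothetical sketch: each RAG model type (token/sequence/turn)
    implements these hooks, and the agent delegates to whichever is active."""

    @abstractmethod
    def get_initial_decoder_input(self, input: torch.LongTensor) -> torch.LongTensor:
        """Return the first decoder input, expanded per retrieved document."""

    @abstractmethod
    def marginalize(self, out_probs: torch.Tensor, doc_probs: torch.Tensor) -> torch.Tensor:
        """Combine per-document generator scores with retrieval scores."""
```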
`parlai/agents/fid/fid.py` (Outdated)
```python
def reorder_encoder_states(
    self,
    encoder_states: Tuple[torch.Tensor, ...],
    indices: Union[List[int], torch.LongTensor],
) -> Tuple[torch.Tensor, torch.Tensor, List[List[Document]], torch.Tensor]:
    """
    Reorder the encoder states.

    Override TGM.reorder_encoder_states to make sure we only pass enc, mask.

    See ``TorchGeneratorModel.reorder_encoder_states`` for a description.
    """
    enc, mask, *_ = encoder_states
    return TransformerGeneratorModel.reorder_encoder_states(
        self, (enc, mask), indices
    )
```
This may be out of the scope of this PR. But I feel like we could eliminate having to override this if `encoder_states` were a dict or `**kwargs` instead of a tuple. In the scope of this PR, you could remove the need for this function by modifying `TransformerGeneratorModel.reorder_encoder_states` to do `enc, mask = encoder_states[:2]` (see the sketch below).
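A sketch of that suggested modification (the body paraphrases the usual TorchGeneratorModel reordering logic; treat the details as approximate rather than the actual ParlAI source):

```python
def reorder_encoder_states(self, encoder_states, indices):
    # Take only the first two elements, so subclasses whose encoder states
    # carry extra entries (e.g. retrieved documents) need not override this.
    enc, mask = encoder_states[:2]
    if not torch.is_tensor(indices):
        indices = torch.LongTensor(indices).to(enc.device)
    enc = torch.index_select(enc, 0, indices)
    mask = torch.index_select(mask, 0, indices)
    return enc, mask
```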
I actually find the forcing nice, as it causes people to realize they need to handle reordering states. kwargs could hide something that isn't being properly shuffled
i should probably clarify this docstring, as we're overriding `RagModel.reorder_encoder_states`, not TGM
"[Fid]: I love Elvis Presley! He is my favorite singer, songwriter, actor, and producer." | ||
), | ||
"example2": ( | ||
"parlai eval_model -mf zoo:hallucination/bart_fid_rag_dpr_poly/model -t wizard_of_wikipedia --num-examples 100" |
gotta remove this and the following `--num-examples 100`, as these are numbers from the full valid set
"example": ( | ||
"parlai eval_model -mf zoo:hallucination/bart_rag_token/model --indexer-type exact --path-to-index zoo:hallucination/wow_passages/exact --path-to-dpr-passages zoo:hallucination/wow_passages/wow_articles.paragraphs.tsv -ne 100" | ||
), | ||
"result": ("TODO"), |
also gotta fill in these results
```sh
--batchsize 16 --fp16 True --gradient-clip 0.1 --label-truncate 128 \
--log-every-n-secs 30 --lr-scheduler reduceonplateau --lr-scheduler-patience 1 \
--model-parallel True --optimizer adam --text-truncate 512 --truncate 512 \
-lr 1e-05 -vmm min -veps 0.25 -vme 1000 -vmt ppl -vp 5 \
```
maybe add a note here to open an issue if other options are desired?
Option prefixes are available now...
i know i added one for BART. i wasn't sure what to call these though... `opt/rag`? since these are optimization/training parameters
you can put them in the project folder and specify it with `-o projects/hallucination/very_long_name.opt`
wow, wasn't aware of that, i'll try it out
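For reference, ParlAI opt presets are JSON files of flag/value pairs; a minimal sketch of the suggested workflow (the filename and values here are illustrative, echoing the flags in the diff above):

```sh
# Hypothetical preset at projects/hallucination/bart_rag.opt containing, e.g.:
#   {"batchsize": 16, "fp16": true, "optimizer": "adam", "learningrate": 1e-05}
# Reference it instead of repeating the flags on every command:
parlai train_model -m rag -t wizard_of_wikipedia -o projects/hallucination/bart_rag.opt
```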
This is great! Unless @stephenroller or @moyapchen have any concerns, it feels ready to merge.
```python
encoder=opt['t5'].get_encoder(),
encoder_class=ParlaiT5Encoder,
```
This is pretty funky. It should be made better by the "swappable subcomponents" change.
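For readers unfamiliar with that planned change, a minimal sketch (class and argument names hypothetical) of the "swappable subcomponents" idea: construction accepts either a prebuilt module or a class to instantiate, so call sites like the diff above stop reaching into opt:

```python
import torch.nn as nn

class GeneratorModel(nn.Module):
    """Hypothetical sketch: pass a prebuilt encoder (e.g. a pretrained
    T5 encoder) or a class from which to build the default one."""

    def __init__(self, opt, encoder=None, encoder_class=None):
        super().__init__()
        # reuse the supplied module if given; otherwise build the swappable default
        self.encoder = encoder if encoder is not None else encoder_class(opt)
```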
Looks like tests still don't pass (maybe merge master into this?) but I defer.
i'll try getting tests to pass before i merge this
Patch description
RAG and FiD, in ParlAI. There's too much to describe in the PR description, so I would direct readers to the included READMEs for more detailed instructions. I'll give a quick list here of what's encompassed by these changes:

- `--model rag`, `--model fid`
- `DropoutPolyencoder`, the base model of the PolyFAISS method described in [this paper](https://arxiv.org/abs/2104.07567)
- an `opt_preset` file for BART-Large

The following RAG options are implemented:
Testing steps
CI