[RAG] Handle Different DPR Model File with Pre-trained Model #3688
dpr_model='bert_from_parlai_rag',
pretrained_path=RAG_SEQUENCE_ZOO_MODEL,
)
assert not torch.allclose(
nit: why not use the unittest functions like `assertIsNone` and `assertTrue`, etc.?
i've spoken with @stephenroller about this and he's told me that it's all the same to pytest
parlai/agents/rag/rag.py
Outdated
logging.warning(
f"Overriding DPR Model with {modelzoo_path(opt['datapath'], opt['dpr_model_file'])}"
)
except FileNotFoundError:
should there be a warning here?
hmm no, since it's supposed to be silent to the user
parlai/agents/rag/rag.py
Outdated
try:
init_model, _ = self._get_init_model(opt, None)
init_model_opt = Opt.load(f'{init_model}.opt')
override_dpr = modelzoo_path(
couldn't it just compare `opt['dpr_model_file'] != init_model_opt['dpr_model_file']`, since the rest seems to be the same?
unfortunately I don't think so -> `modelzoo_path` only modifies the path if it starts with `zoo:`, otherwise it's a no-op. Someone could, theoretically, pass in the full path to their dpr model file, even if it is the zoo path, so we need to make sure that the reference isn't the same
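To illustrate the point above, here is a minimal sketch of why comparing the raw option strings is not enough. `resolve_zoo_path` is a hypothetical stand-in written only for this example: it mimics the behavior described in the comment (rewrite `zoo:` paths, leave everything else untouched) and is not ParlAI's actual `modelzoo_path`. The `models` subdirectory and the example file names are assumptions as well.

```python
import os


def resolve_zoo_path(datapath: str, path: str) -> str:
    """Hypothetical stand-in: only paths starting with 'zoo:' are rewritten."""
    if path.startswith("zoo:"):
        # Assumed layout: zoo files live under <datapath>/models/
        return os.path.join(datapath, "models", path[len("zoo:"):])
    # Any other path is returned unchanged (a no-op).
    return path


def same_dpr_file(datapath: str, current: str, trained_with: str) -> bool:
    """Compare the two references only after both have been resolved."""
    return resolve_zoo_path(datapath, current) == resolve_zoo_path(
        datapath, trained_with
    )


datapath = "/data"  # made-up datapath
zoo_ref = "zoo:example/dpr/model"  # made-up zoo reference
full_ref = os.path.join(datapath, "models", "example/dpr/model")

raw_strings_differ = zoo_ref != full_ref
resolved_refs_match = same_dpr_file(datapath, zoo_ref, full_ref)
```

Here the raw strings differ even though both point at the same file, so a plain `!=` on the options would wrongly conclude that a different DPR model was requested.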
Thanks for fixing it so quickly.
parlai/agents/rag/rag.py
Outdated
by the state loading.

NOTE: If `--model-file M` was trained with `--dpr-model-file D`, and
`--dpr-model-file D` is specified *after training* (i.e., in eval/interactive),
Thought here - should there be an "--override-dpr-model-file-for-eval"?
(This was the initial case that triggered everything so...)
i think this case i mention here was not what triggered things, but with the new implementation it's handled smoothly
See RagAgent._should_override_dpr_model_weights for important note
regarding specifying the *same* dpr model file as was used to train
the model.
"""
Nit: For clarity's sake, might be nice to be explicit about the 3 cases that show up and expected behavior
A -> Init DPR Model for M
B -> DPR Model within M after training M
C -> New DPR Model you're using to override
added a comment!
i actually think there is a simpler way to do this, implementing now
Re-requesting review as this can be handled better
LGTM. Thanks!
Patch description
This patch fixes a bug that arises when one would like to swap the `--dpr-model-file` for a pre-trained RAG or FiD model. The underlying issue is that the DPR model weights are loaded during model initialization, PRIOR TO loading the pre-trained model weights. This is fine when the pre-trained model is an underlying seq2seq model (BART, T5, BB), but it is not OK when the pre-trained model is a RAG or FiD model, since loading the pre-trained weights then replaces the newly specified DPR weights. The solution is to overwrite the retriever weights in the `state_dict` with the already loaded DPR model weights.
Testing steps
Included CI testing:
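For reference, the fix described in the patch description (overwriting the retriever entries of the pre-trained `state_dict` with the DPR weights that were already loaded) can be sketched as below. This is an illustrative toy, not the PR's code or its CI tests: plain dicts and lists stand in for tensors, and the function name, the `retriever.` key prefix, and all keys and values are assumptions made for the example.

```python
def override_retriever_weights(pretrained_state, loaded_dpr_state, prefix="retriever."):
    """Return a copy of pretrained_state whose retriever entries are
    replaced by the already-loaded DPR weights.

    The 'retriever.' prefix is an assumed naming convention for this sketch.
    """
    merged = dict(pretrained_state)  # shallow copy; the input is left untouched
    for key, value in loaded_dpr_state.items():
        if key.startswith(prefix):
            merged[key] = value
    return merged


# Made-up weights: the pre-trained RAG/FiD checkpoint carries its own retriever.
pretrained = {
    "retriever.query_encoder.weight": [0.1, 0.2],
    "seq2seq.decoder.weight": [0.5],
}
# Weights from the newly specified DPR file, loaded during initialization.
new_dpr = {"retriever.query_encoder.weight": [0.9, 0.8]}

merged = override_retriever_weights(pretrained, new_dpr)
```

Loading `merged` instead of `pretrained` keeps the non-retriever weights from the checkpoint while preserving the swapped-in retriever.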