This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

[RAG] Handle Different DPR Model File with Pre-trained Model #3688

Merged: klshuster merged 4 commits into master from fix_dpr_mf on Jun 8, 2021

klshuster
Contributor

Patch description
This patch fixes a bug that occurs when swapping in a different --dpr-model-file for a pre-trained RAG or FiD model. The underlying issue is that the DPR model weights are loaded during model initialization, before the pre-trained model weights are loaded. This is fine when the pre-trained model is an underlying seq2seq model (BART, T5, BB), but not when the pre-trained model is a RAG or FiD model, because loading that model's weights then replaces the DPR weights that were just loaded.

The solution is to overwrite the retriever weights in the state_dict with the already loaded DPR model weights.
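
The idea of patching the state dict before it is loaded can be sketched as follows. This is a minimal illustration, not the code from this patch: the function name `override_retriever_weights`, the `retriever.` key prefix, and the parameter names are assumptions, and plain lists stand in for tensors.

```python
from typing import Dict, List

# Plain lists of floats stand in for tensors in this sketch.
StateDict = Dict[str, List[float]]


def override_retriever_weights(
    pretrained_state: StateDict,
    loaded_retriever_state: StateDict,
    prefix: str = "retriever.",  # assumed key prefix, not taken from ParlAI
) -> StateDict:
    """Return a copy of pretrained_state whose retriever entries are replaced
    by the retriever weights that were already loaded at initialization."""
    patched = dict(pretrained_state)
    for name, weights in loaded_retriever_state.items():
        key = prefix + name
        if key in patched:
            # Keep the freshly loaded DPR weights instead of the pre-trained ones.
            patched[key] = weights
    return patched


# Hypothetical parameter names, for illustration only.
pretrained = {
    "retriever.query_encoder.weight": [0.1, 0.2],
    "seq2seq.decoder.weight": [0.3, 0.4],
}
new_dpr = {"query_encoder.weight": [9.0, 9.0]}
patched = override_retriever_weights(pretrained, new_dpr)
```

Loading `patched` instead of `pretrained` would leave the retriever as initialized while still restoring every other weight from the pre-trained model.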

Testing steps
Included CI testing:

$ pytest -k TestLoadDPRModel
===== test session starts =====
platform linux -- Python 3.7.9, pytest-6.2.1, py-1.10.0, pluggy-1.0.0.dev0
rootdir: /private/home/kshuster/ParlAI, configfile: pytest.ini
plugins: hydra-core-1.0.0, requests-mock-1.8.0, regressions-2.1.1, datadir-1.3.1
collected 96 items / 95 deselected / 1 selected

test_rag.py .                                                                                                                                                                      [100%]

=====slowest 10 durations =====
41.01s call     tests/nightly/gpu/test_rag.py::TestLoadDPRModel::test_load_dpr

(2 durations < 0.005s hidden.  Use -vv to show these durations.)
=====1 passed, 95 deselected, 4 warnings in 43.30s =====

dpr_model='bert_from_parlai_rag',
pretrained_path=RAG_SEQUENCE_ZOO_MODEL,
)
assert not torch.allclose(
Contributor

nit: why not use the unittest functions like assertIsNone, assertTrue, etc.?

Contributor Author

I've spoken with @stephenroller about this, and he told me that it's all the same to pytest.
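
For illustration, the two assertion styles discussed here can sit side by side in one test case. This is a small sketch with a hypothetical test class, unrelated to the tests in this patch. It is run with the standard unittest runner so that it is self-contained; pytest also collects `unittest.TestCase` subclasses and reports a failing plain `assert` as a test failure.

```python
import unittest


class TestAssertStyles(unittest.TestCase):
    def test_plain_assert(self):
        value = None
        # Plain assert: raises AssertionError on failure.
        assert value is None

    def test_unittest_assert(self):
        value = None
        # unittest helper: also raises AssertionError on failure.
        self.assertIsNone(value)


# Run both tests programmatically.
suite = unittest.defaultTestLoader.loadTestsFromTestCase(TestAssertStyles)
result = unittest.TextTestRunner(verbosity=0).run(suite)
```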

logging.warning(
f"Overriding DPR Model with {modelzoo_path(opt['datapath'], opt['dpr_model_file'])}"
)
except FileNotFoundError:
Contributor

Should there be a warning here?

Contributor Author

hmm no, since it's supposed to be silent to the user

try:
init_model, _ = self._get_init_model(opt, None)
init_model_opt = Opt.load(f'{init_model}.opt')
override_dpr = modelzoo_path(
Contributor

couldn't it just compare opt['dpr_model_file'] != init_model_opt['dpr_model_file'], since the rest seems to be the same?

Contributor Author

Unfortunately, I don't think so: modelzoo_path only modifies the path if it starts with zoo:; otherwise it's a no-op. Someone could, theoretically, pass in the full path to their DPR model file even when that file is the same one the zoo path points to, so we need to compare the resolved paths to make sure the two don't refer to the same file.
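
A rough sketch of why the raw strings cannot be compared directly. `resolve_zoo_path` below is a hypothetical stand-in that only mimics the behavior described in this comment (rewrite `zoo:` paths, leave everything else alone); it is not ParlAI's `modelzoo_path`, and the `models` subdirectory and the example paths are assumptions.

```python
import os


def resolve_zoo_path(datapath: str, path: str) -> str:
    """Hypothetical stand-in: expand 'zoo:' paths, otherwise a no-op."""
    if path.startswith("zoo:"):
        # Assumed layout: zoo files live under <datapath>/models/.
        return os.path.join(datapath, "models", path[len("zoo:"):])
    return path


def is_different_dpr_file(datapath: str, new_path: str, trained_path: str) -> bool:
    """Compare the resolved locations rather than the raw option strings."""
    return resolve_zoo_path(datapath, new_path) != resolve_zoo_path(
        datapath, trained_path
    )


datapath = "/data"
zoo_ref = "zoo:example/dpr/model"  # hypothetical zoo reference
full_ref = os.path.join(datapath, "models", "example/dpr/model")  # same file, full path
```

Here `zoo_ref != full_ref` as strings, yet `is_different_dpr_file(datapath, zoo_ref, full_ref)` is false because both resolve to the same location.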

Contributor

mojtaba-komeili left a comment

Thanks for fixing it so quickly.

by the state loading.

NOTE: If `--model-file M` was trained with `--dpr-model-file D`, and
`--dpr-model-file D` is specified *after training* (i.e., in eval/interactive),
Contributor

Thought here - should there be an "--override-dpr-model-file-for-eval"?

(This was the initial case that triggered everything so...)

Contributor Author

I think the case I mention here was not what triggered things, but with the new implementation it's handled smoothly.

See RagAgent._should_override_dpr_model_weights for important note
regarding specifying the *same* dpr model file as was used to train
the model.
"""
Contributor

Nit: For clarity's sake, might be nice to be explicit about the 3 cases that show up and expected behavior

A -> Init DPR Model for M
B -> DPR Model within M after training M
C -> New DPR Model you're using to override

Contributor Author

added a comment!

@klshuster
Contributor Author

I actually think there is a simpler way to do this; implementing now.

@klshuster
Contributor Author

Re-requesting review as this can be handled better

Contributor

moyapchen left a comment

LGTM. Thanks!

klshuster merged commit 64d3859 into master on Jun 8, 2021
klshuster deleted the fix_dpr_mf branch on June 8, 2021, 18:24