This repository was archived by the owner on Nov 3, 2023. It is now read-only.

[RAG] Generate Dense Embeddings Fix #3869

Merged (1 commit) on Jul 27, 2021

6 changes: 3 additions & 3 deletions parlai/agents/rag/README.md
@@ -98,7 +98,7 @@ The default RAG parameters use the `zoo:hallucination/wiki_passages/psgs_w100.ts

### 1a. [**Recommended**] Obtain/Choose a (Pre-trained) DPR Model

- The RAG model works **really well** with DPR models as the backbone retrievers; check out the [DPR repository](https://github.com/facebookresearch/DPR) for some pre-trained DPR models (or, train your own!).
+ The RAG model works **really well** with DPR models as the backbone retrievers; check out the [DPR repository](https://github.com/facebookresearch/DPR) for some pre-trained DPR models (or, train your own!). Alternatively, you can specify a RAG or FiD model with DPR weights (perhaps, e.g., one from the ParlAI model zoo, such as `zoo:hallucination/bart_rag_token/model`).

### 1b. Train your own Dropout Poly-encoder

@@ -115,12 +115,12 @@ Check `/path/to/ParlAI/data/models/hallucination/wiki_passages/psgs_w100.tsv` fo
Then, you can use the [`generate_dense_embeddings.py`](https://github.com/facebookresearch/ParlAI/blob/master/parlai/agents/rag/scripts/generate_dense_embeddings.py) script to run the following command:

```bash
- python generate_dense_embeddings.py -mf /path/to/dpr/model --dpr-model True \
+ python generate_dense_embeddings.py --model-file /path/to/dpr/model --dpr-model True \
--passages-file /path/to/passages --outfile /path/to/saved/embeddings \
--shard-id <shard_id> --num-shards <num_shards> -bs <batchsize>
```
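
To cover every shard, the updated command can be generated once per shard id. The following is a minimal sketch, not part of the PR: `shard_commands` is a hypothetical helper, it uses only the flags shown in the command above, and it assumes shard ids are numbered from 0 to `num_shards - 1`.

```python
import shlex


def shard_commands(model_file, passages_file, outfile, num_shards, batchsize, dpr_model=True):
    """Build one generate_dense_embeddings.py argument list per shard.

    Assumption: shard ids are numbered 0 .. num_shards - 1.
    """
    commands = []
    for shard_id in range(num_shards):
        commands.append([
            "python", "generate_dense_embeddings.py",
            "--model-file", model_file,
            "--dpr-model", str(dpr_model),
            "--passages-file", passages_file,
            "--outfile", outfile,
            "--shard-id", str(shard_id),
            "--num-shards", str(num_shards),
            "-bs", str(batchsize),
        ])
    return commands


if __name__ == "__main__":
    # Print the commands instead of running them, so they can be inspected first.
    for cmd in shard_commands(
        "/path/to/dpr/model",
        "/path/to/passages",
        "/path/to/saved/embeddings",
        num_shards=2,
        batchsize=32,
    ):
        print(shlex.join(cmd))
```

Each argument list can be passed to `subprocess.run` as-is; printing first keeps the sketch free of side effects.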

- The `--dpr-model True` flag signifies that the model file you are providing is a DPR model; if you use a Dropout Poly-encoder, set this to `False`. The script will generate embeddings with the DPR model for shard `<shard_id>` of the data, and save two files:
+ If the provided `--model-file` is either a path to a DPR model or a path to a ParlAI RAG/FiD model, specify `--dpr-model True` so that the script can appropriately extract the DPR weights; if you use a Dropout Poly-encoder, set `--dpr-model` to `False`. The script will generate embeddings with the DPR model for shard `<shard_id>` of the data, and save two files:

- `/path/to/saved/embeddings_<shard_id>`: The concatenated tensor of embeddings
- `/path/to/saved/ids_<shard_id>`: The list of document ids that corresponds to these embeddings.
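
The naming scheme in the list above makes it easy to check which shard outputs exist before moving on. The sketch below is illustrative only: `shard_output_paths` and `missing_outputs` are hypothetical helpers, and shard ids are again assumed to run from 0 to `num_shards - 1`.

```python
import os


def shard_output_paths(outfile, num_shards):
    """Return (embeddings_path, ids_path) for each shard.

    Follows the naming described above: `<outfile>_<shard_id>` for the
    embeddings and `ids_<shard_id>` in the same directory for the ids.
    Assumption: shard ids are numbered 0 .. num_shards - 1.
    """
    out_dir = os.path.dirname(outfile)
    paths = []
    for shard_id in range(num_shards):
        embeddings_path = f"{outfile}_{shard_id}"
        ids_path = os.path.join(out_dir, f"ids_{shard_id}")
        paths.append((embeddings_path, ids_path))
    return paths


def missing_outputs(outfile, num_shards):
    """List expected shard files that do not exist on disk yet."""
    return [
        path
        for pair in shard_output_paths(outfile, num_shards)
        for path in pair
        if not os.path.exists(path)
    ]
```

Calling `missing_outputs("/path/to/saved/embeddings", num_shards)` after the shard jobs finish returns an empty list when every expected file is present.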
12 changes: 9 additions & 3 deletions parlai/agents/rag/dpr.py
@@ -203,13 +203,19 @@ def _get_build_options(cls, opt: Opt):
         try:
             # determine if loading a RAG model
             loaded_opt = Opt.load(f"{query_path}.opt")
-            if loaded_opt['model'] == 'rag' and loaded_opt['query_model'] in [
+            document_path = loaded_opt.get('dpr_model_file', document_path)
Contributor:

Does this slow things down if you have a FiD model and later load document_path again? In that case, does it make sense to move this to the else condition of the following if?

Contributor Author:

I'm not entirely sure what you're asking here. We're only getting the document path from the .opt file; we should never actually build a FiD model in this scenario.

+            if loaded_opt['model'] in ['rag', 'fid'] and loaded_opt['query_model'] in [
                 'bert',
                 'bert_from_parlai_rag',
             ]:
                 query_model = 'bert_from_parlai_rag'
-                # document model is always frozen
-                document_path = loaded_opt.get('dpr_model_file', document_path)
+                if loaded_opt['model'] == 'fid':
+                    # document model is always frozen
+                    # but may be loading a FiD-RAG Model
+                    doc_loaded_opt = Opt.load(
+                        f"{modelzoo_path(opt['datapath'], document_path)}.opt"
+                    )
+                    document_path = doc_loaded_opt.get('dpr_model_file', document_path)

         except FileNotFoundError:
             pass
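
As we read the hunk above, the updated lookup follows `dpr_model_file` one step for RAG models and two steps for FiD models. The sketch below replays that logic on plain dicts so it can be run in isolation. It is not the ParlAI code: `resolve_models`, `load_fake_opt`, and `FAKE_OPTS` are hypothetical, `load_opt` stands in for `Opt.load(f"{path}.opt")`, and the `modelzoo_path` resolution is left out.

```python
def resolve_models(query_path, document_path, load_opt):
    """Replay the updated lookup using plain dicts in place of Opt.

    `load_opt(path)` is a hypothetical stand-in for Opt.load(f"{path}.opt");
    it returns a dict or raises FileNotFoundError.
    Returns (query_model, document_path); query_model is None when the
    .opt file gives no reason to override it.
    """
    query_model = None
    try:
        loaded_opt = load_opt(query_path)
        document_path = loaded_opt.get("dpr_model_file", document_path)
        if loaded_opt["model"] in ("rag", "fid") and loaded_opt["query_model"] in (
            "bert",
            "bert_from_parlai_rag",
        ):
            query_model = "bert_from_parlai_rag"
            if loaded_opt["model"] == "fid":
                # The FiD .opt may point at a RAG model, whose own .opt
                # names the underlying DPR weights.
                doc_loaded_opt = load_opt(document_path)
                document_path = doc_loaded_opt.get("dpr_model_file", document_path)
    except FileNotFoundError:
        # No .opt file next to the path: keep whatever was resolved so far.
        pass
    return query_model, document_path


# Made-up .opt contents for illustration only.
FAKE_OPTS = {
    "/models/fid": {"model": "fid", "query_model": "bert", "dpr_model_file": "/models/rag"},
    "/models/rag": {"model": "rag", "query_model": "bert", "dpr_model_file": "/models/dpr"},
}


def load_fake_opt(path):
    if path not in FAKE_OPTS:
        raise FileNotFoundError(path)
    return FAKE_OPTS[path]
```

With these made-up entries, both `/models/fid` and `/models/rag` resolve to `/models/dpr`, while a path with no `.opt` entry leaves the inputs unchanged.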