EMDR2 is an end-to-end training algorithm developed for the task of open-domain question answering.
Comparison with other approaches
End-to-End Training
Results
To use this repo, we recommend using one of NGC's more recent PyTorch containers. The image version used in the paper can be pulled with the command

```bash
docker pull nvcr.io/nvidia/pytorch:20.03-py3
```

An installation of the NVIDIA Container Toolkit may also be required.
Additional dependencies need to be installed. We provide a Dockerfile for them in the `docker` directory.
For instance, to build a new Docker image (`nvcr.io/nvidia/pytorch:20.10-py3-faiss-compiled`) on top of the base container, use these commands:

```bash
cd docker
sudo docker build -t nvcr.io/nvidia/pytorch:20.10-py3-faiss-compiled .
```
To run this image in interactive mode, use the following command, where `/mnt/disks` is the directory to be mounted:

```bash
sudo docker run --ipc=host --gpus all -it --rm -v /mnt/disks:/mnt/disks nvcr.io/nvidia/pytorch:20.10-py3-faiss-compiled bash
```
We provide pretrained checkpoints and datasets on Dropbox for training models for open-domain QA and dense retrieval. These files can be downloaded with the `wget` command-line utility using the links below.
- Wikipedia evidence passages from the DPR paper
- Pre-tokenized evidence passages and their titles
- Dataset-specific question-answer pairs
- BERT-large vocabulary file
- Masked Salient Span (MSS) pre-trained retriever
- Masked Salient Span (MSS) pre-trained reader
- Precomputed evidence embeddings using the MSS retriever: this is a large file (32 GB).
We also provide data for Masked Salient Spans (MSS) training (URL). This file contains around 20M sentences extracted from the Wikipedia passages file, together with the positions of the named entities in each sentence. To obtain these named entities, we used the pre-trained OntoNotes-5.0 model provided by the Stanza toolkit.
An example line from the file (in jsonlines format) is:

```json
{"doc_id": 209, "sent_text": "Karpov 's outstanding classical tournament play has been seriously limited since 1997 , since he prefers to be more involved in the politics of his home country of Russia .", "bert_ent_pos": [[14, 14], [31, 31]], "linguistic_ent": [["1997", "DATE", 11, 11], ["Russia", "GPE", 28, 28]]}
```
The fields are:
- `doc_id`: the passage ID in the evidence.
- `sent_text`: the sentence text as obtained after Stanza tokenization.
- `linguistic_ent`: the list of named entities, each in the format (named entity text, entity type, entity start position, entity end position). In the example, end positions are inclusive: the single-token entity "1997" has start and end position 11.
- `bert_ent_pos`: the start and end positions of the entities after BERT tokenization.
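As a rough illustration of the record format, the Python sketch below parses the example line, recovers each entity span from `sent_text` using `linguistic_ent`, and masks one span with `[MASK]` tokens. It assumes that `linguistic_ent` positions index whitespace-separated tokens with an inclusive end, which holds for the example line. `salient_spans` and `mask_span` are hypothetical helpers, not functions from this repository; the word-level masking is only an illustration, not the repository's MSS preprocessing, and it ignores `bert_ent_pos`, which refers to BERT word-piece positions.

```python
import json

# The example record from the MSS training file shown above.
LINE = (
    '{"doc_id": 209, "sent_text": "Karpov \'s outstanding classical tournament play '
    'has been seriously limited since 1997 , since he prefers to be more involved in '
    'the politics of his home country of Russia .", '
    '"bert_ent_pos": [[14, 14], [31, 31]], '
    '"linguistic_ent": [["1997", "DATE", 11, 11], ["Russia", "GPE", 28, 28]]}'
)


def salient_spans(record):
    """Return (entity text, entity type, span recovered from sent_text) triples.

    Assumption: linguistic_ent positions index the whitespace-separated tokens
    of sent_text, and the end position is inclusive.
    """
    tokens = record["sent_text"].split()
    return [
        (text, ent_type, " ".join(tokens[start : end + 1]))
        for text, ent_type, start, end in record["linguistic_ent"]
    ]


def mask_span(record, index, mask_token="[MASK]"):
    """Replace every token of the index-th entity with mask_token (word level only)."""
    tokens = record["sent_text"].split()
    _, _, start, end = record["linguistic_ent"][index]
    tokens[start : end + 1] = [mask_token] * (end - start + 1)
    return " ".join(tokens)


record = json.loads(LINE)
for text, ent_type, span in salient_spans(record):
    print(f"{ent_type}: {text!r} -> {span!r}")
print(mask_span(record, 1))
```

If the recovered span matches the entity text for every entry, the position convention assumed here is consistent with the file.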
For more details on MSS training, please refer to the papers below.
We provide several scripts for training models for both dense retrieval and open-domain QA in the `examples` directory. Please make sure to update the data and checkpoint paths in these scripts.
To replicate the answer generation results on the Natural Questions (NQ) dataset, run:

```bash
bash examples/openqa/emdr2_nq.sh
```

Similar scripts are provided for TriviaQA and WebQuestions, as well as for training the dense retriever.
For end-to-end training, we used a single node with 16 A100 GPUs (40 GB of memory each). In the codebase, the first 8 GPUs are used for model training, the other 8 GPUs are used for asynchronous evidence embedding, and all 16 GPUs are used for online retrieval at every step.
The code can also be run on a node with 8 GPUs by disabling asynchronous evidence embedding computation, although this can lead to some loss in performance.
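A minimal sketch of the GPU split described above, written as a hypothetical helper (`assign_gpu_roles` and `GpuRoles` are not part of the codebase). It assumes that, when asynchronous embedding is disabled, every GPU is used for both training and retrieval; the actual scripts may assign devices differently.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class GpuRoles:
    trainer: List[int]    # GPUs that run model training
    indexer: List[int]    # GPUs that re-embed the evidence asynchronously
    retrieval: List[int]  # GPUs that take part in online retrieval at every step


def assign_gpu_roles(num_gpus: int = 16, async_indexer: bool = True) -> GpuRoles:
    """Hypothetical mapping of GPU ids to roles, following the README's description."""
    gpus = list(range(num_gpus))
    if async_indexer:
        half = num_gpus // 2
        # First half trains, second half embeds evidence, all GPUs retrieve.
        return GpuRoles(trainer=gpus[:half], indexer=gpus[half:], retrieval=gpus)
    # Assumption: without the asynchronous indexer, every GPU trains and retrieves.
    return GpuRoles(trainer=gpus, indexer=[], retrieval=gpus)


print(assign_gpu_roles())
print(assign_gpu_roles(num_gpus=8, async_indexer=False))
```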
Dataset | Dev EM | Test EM | Checkpoint | Precomputed Evidence Embedding |
---|---|---|---|---|
Natural Questions | 50.42 | 52.49 | link | link |
TriviaQA | 71.13 | 71.43 | link | link |
WebQuestions | 49.86 | 48.67 | link | link |
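The table reports exact match (EM) scores. The sketch below assumes the standard SQuAD-style EM used in most open-domain QA work (lowercasing, stripping punctuation and the articles a/an/the, and collapsing whitespace); the repository's own evaluation code may normalize answers differently, and the function names are illustrative.

```python
import re
import string

PUNCTUATION = set(string.punctuation)


def normalize_answer(text: str) -> str:
    """SQuAD-style normalization: lowercase, drop punctuation and articles, squeeze spaces."""
    text = text.lower()
    text = "".join(ch for ch in text if ch not in PUNCTUATION)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())


def exact_match(prediction: str, gold_answers: list) -> bool:
    """True if the normalized prediction equals any normalized gold answer."""
    pred = normalize_answer(prediction)
    return any(pred == normalize_answer(gold) for gold in gold_answers)


def em_score(predictions: list, gold: list) -> float:
    """Percentage of questions answered exactly, the scale used in the table above."""
    hits = sum(exact_match(p, g) for p, g in zip(predictions, gold))
    return 100.0 * hits / len(predictions)


print(em_score(["The Eiffel Tower!", "London"], [["eiffel tower"], ["Berlin"]]))  # 50.0
```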
To use these checkpoints, set the variables `CHECKPOINT_PATH` and `EMBEDDING_PATH` to point to the checkpoint and the embedding index above, respectively.
Also, add the option `--no-load-optim` and remove the options `--emdr2-training --async-indexer --index-reload-interval 500` from the example script so that it runs in inference mode.
As the memory requirement for inference is lower, evaluation can also be performed on 4-8 GPUs.
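To make the option changes concrete, the sketch below applies them to a command-line argument list. `to_inference_args` is a hypothetical helper, not part of the repository; it assumes `--emdr2-training` and `--async-indexer` take no value and `--index-reload-interval` takes exactly one. The `--load e2eqa/trivia` pair in the usage line only stands in for whatever other arguments the example script passes.

```python
from typing import List

# Training-only options the README says to remove for inference.
VALUELESS_FLAGS = {"--emdr2-training", "--async-indexer"}
FLAGS_WITH_VALUE = {"--index-reload-interval"}  # assumed to take exactly one value


def to_inference_args(train_args: List[str]) -> List[str]:
    """Drop the training-only options and append --no-load-optim."""
    result = []
    skip_value = False
    for arg in train_args:
        if skip_value:  # this token is the value of a removed option, e.g. "500"
            skip_value = False
            continue
        name = arg.split("=", 1)[0]
        if name in VALUELESS_FLAGS:
            continue
        if name in FLAGS_WITH_VALUE:
            skip_value = "=" not in arg  # "--flag=500" carries its own value
            continue
        result.append(arg)
    if "--no-load-optim" not in result:
        result.append("--no-load-optim")
    return result


train_args = ["--load", "e2eqa/trivia", "--emdr2-training", "--async-indexer",
              "--index-reload-interval", "500"]
print(" ".join(to_inference_args(train_args)))
# --load e2eqa/trivia --no-load-optim
```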
- Sometimes we need to save the retriever model on its own, for example for top-K recall evaluation. To save only the retriever model from a checkpoint, use this command:

  ```bash
  python tools/save_emdr2_models.py --submodel-name retriever --load e2eqa/trivia --save e2eqa/trivia/retriever/
  ```
- To create evidence embeddings from a retriever checkpoint and perform top-K recall evaluation, use this script. Make sure to set the dataset and checkpoint paths correctly.

  ```bash
  bash examples/helper-scripts/create_wiki_indexes_and_evaluate.sh
  ```
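Top-K recall is assumed here to follow the usual DPR-style definition: a question counts as a hit if any of its top-K retrieved passages contains a gold answer. The sketch below uses a simplified token match (lowercase, punctuation replaced by spaces); the repository's script may tokenize differently, and `has_answer` and `top_k_recall` are illustrative names.

```python
import string
from typing import List, Sequence

PUNCTUATION = set(string.punctuation)


def tokenize(text: str) -> List[str]:
    """Lowercase and split on whitespace after replacing punctuation with spaces."""
    return "".join(" " if ch in PUNCTUATION else ch for ch in text.lower()).split()


def has_answer(passage: str, answers: Sequence[str]) -> bool:
    """True if any answer's token sequence appears contiguously in the passage."""
    p = tokenize(passage)
    for answer in answers:
        a = tokenize(answer)
        if a and any(p[i : i + len(a)] == a for i in range(len(p) - len(a) + 1)):
            return True
    return False


def top_k_recall(retrieved: Sequence[Sequence[str]],
                 answers: Sequence[Sequence[str]], k: int) -> float:
    """Percentage of questions with a gold answer in their top-k passages."""
    hits = sum(
        any(has_answer(passage, golds) for passage in passages[:k])
        for passages, golds in zip(retrieved, answers)
    )
    return 100.0 * hits / len(retrieved)


retrieved = [
    ["Paris is the capital of France.", "Lyon is a large city."],
    ["Berlin is in Germany.", "The Thames flows through London."],
]
answers = [["Paris"], ["London"]]
print(top_k_recall(retrieved, answers, k=1), top_k_recall(retrieved, answers, k=2))
# 50.0 100.0
```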
For any errors or bugs in the codebase, please either open a new issue or send an email to Devendra Singh Sachan (sachan.devendra@gmail.com).
If you find this code or data useful, please consider citing our paper:
```bibtex
@inproceedings{sachan2021endtoend,
  title={End-to-End Training of Multi-Document Reader and Retriever for Open-Domain Question Answering},
  author={Devendra Singh Sachan and Siva Reddy and William L. Hamilton and Chris Dyer and Dani Yogatama},
  booktitle={Advances in Neural Information Processing Systems},
  editor={A. Beygelzimer and Y. Dauphin and P. Liang and J. Wortman Vaughan},
  year={2021},
  url={https://openreview.net/forum?id=5KWmB6JePx}
}
```
Some of the ideas and implementations in this work are based on a previous paper. If the code is helpful, please also consider citing it:
```bibtex
@inproceedings{sachan-etal-2021-end,
  title = "End-to-End Training of Neural Retrievers for Open-Domain Question Answering",
  author = "Sachan, Devendra and Patwary, Mostofa and Shoeybi, Mohammad and Kant, Neel and Ping, Wei and Hamilton, William L. and Catanzaro, Bryan",
  booktitle = "Proceedings of the 59th Annual Meeting of the Association for Computational Linguistics and the 11th International Joint Conference on Natural Language Processing (Volume 1: Long Papers)",
  month = aug,
  year = "2021",
  address = "Online",
  publisher = "Association for Computational Linguistics",
  url = "https://aclanthology.org/2021.acl-long.519",
  doi = "10.18653/v1/2021.acl-long.519",
  pages = "6648--6662"
}
```