Fix typos in README and bugs in RAG example code for end-to-end evaluation and finetuning (#9355)

* fix a bug in eval_batch_retrieval

* should return parser, as the other staticmethods do

* remove duplicate argument

* these kwargs are no longer accepted (they cause a TypeError in self.generator.generate of modeling_rag.py)

* fixed file paths in README

* moved an arg to add_ray_specific_args
yoshitomo-matsubara authored Jan 3, 2021
1 parent c4fd609 commit d944966
Showing 3 changed files with 19 additions and 33 deletions.
20 changes: 10 additions & 10 deletions examples/research_projects/rag/README.md
@@ -23,10 +23,10 @@ test.source
test.target
```

-A sample finetuning command (run ` ./examples/rag/finetune_rag.py --help` to list all available options):
+A sample finetuning command (run ` ./examples/research_projects/rag/finetune_rag.py --help` to list all available options):

```bash
-python examples/rag/finetune_rag.py \
+python examples/research_projects/rag/finetune_rag.py \
--data_dir $DATA_DIR \
--output_dir $OUTPUT_DIR \
--model_name_or_path $MODEL_NAME_OR_PATH \
@@ -42,7 +42,7 @@ The `base` models initialize the question encoder with [`facebook/dpr-question_e

If you would like to initialize finetuning with a base model using different question encoder and generator architectures, you can build it with a consolidation script, e.g.:
```
-python examples/rag/consolidate_rag_checkpoint.py \
+python examples/research_projects/rag/consolidate_rag_checkpoint.py \
--model_type rag_sequence \
--generator_name_or_path facebook/bart-large-cnn \
--question_encoder_name_or_path facebook/dpr-question_encoder-single-nq-base \
@@ -71,7 +71,7 @@ Also make sure to start the Ray cluster before running fine-tuning.
# Start a single-node Ray cluster.
ray start --head

-python examples/rag/finetune_rag.py \
+python examples/research_projects/rag/finetune_rag.py \
--data_dir $DATA_DIR \
--output_dir $OUTPUT_DIR \
--model_name_or_path $MODEL_NAME_OR_PATH \
@@ -113,14 +113,14 @@ We demonstrate how to evaluate retrieval against DPR evaluation data. You can do
2. Parse the unzipped file using `parse_dpr_relevance_data.py`
```bash
mkdir output # or wherever you want to save this
-python examples/rag/parse_dpr_relevance_data.py \
+python examples/research_projects/rag/parse_dpr_relevance_data.py \
--src_path biencoder-nq-dev.json \
--evaluation_set output/biencoder-nq-dev.questions \
--gold_data_path output/biencoder-nq-dev.pages
```
3. Run evaluation:
```bash
-python examples/rag/eval_rag.py \
+python examples/research_projects/rag/eval_rag.py \
--model_name_or_path facebook/rag-sequence-nq \
--model_type rag_sequence \
--evaluation_set output/biencoder-nq-dev.questions \
@@ -131,7 +131,7 @@ We demonstrate how to evaluate retrieval against DPR evaluation data. You can do
```
```bash
# EXPLANATION
-python examples/rag/eval_rag.py \
+python examples/research_projects/rag/eval_rag.py \
--model_name_or_path facebook/rag-sequence-nq \ # model name or path of the model we're evaluating
--model_type rag_sequence \ # RAG model type (rag_token or rag_sequence)
--evaluation_set output/biencoder-nq-dev.questions \ # an input dataset for evaluation
@@ -159,7 +159,7 @@ Add `--recalculate` parameter to force the script to perform inference from scra
An example e2e evaluation run could look as follows:
```bash
-python examples/rag/eval_rag.py \
+python examples/research_projects/rag/eval_rag.py \
--model_name_or_path facebook/rag-sequence-nq \
--model_type rag_sequence \
--evaluation_set path/to/test.source \
@@ -179,14 +179,14 @@ With `use_custom_knowledge_dataset.py` you can build your own knowledge source,

For instance, if documents are serialized as tab-separated csv files with the columns "title" and "text", one can use `use_own_knowledge_dataset.py` as follows:
```bash
-python examples/rag/use_own_knowledge_dataset.py \
+python examples/research_projects/rag/use_own_knowledge_dataset.py \
--csv_path path/to/my_csv \
--output_dir path/to/my_knowledge_dataset \
```

The created outputs in `path/to/my_knowledge_dataset` can then be used to finetune RAG as follows:
```bash
-python examples/rag/finetune_rag.py \
+python examples/research_projects/rag/finetune_rag.py \
--data_dir $DATA_DIR \
--output_dir $OUTPUT_DIR \
--model_name_or_path $MODEL_NAME_OR_PATH \
2 changes: 0 additions & 2 deletions examples/research_projects/rag/eval_rag.py
@@ -130,8 +130,6 @@ def evaluate_batch_e2e(args, rag_model, questions):
early_stopping=False,
num_return_sequences=1,
bad_words_ids=[[0, 0]],  # BART likes to repeat BOS tokens, don't allow it to generate more than one
-clean_up_tokenization=True,
-print_docs=args.print_docs,
)
answers = rag_model.retriever.generator_tokenizer.batch_decode(outputs, skip_special_tokens=True)

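The two dropped keyword arguments are exactly what the commit message describes: current `generate()` signatures in `modeling_rag.py` no longer accept them, so the old call raised a `TypeError` before any generation ran. A minimal, self-contained sketch of that failure mode (a stand-in function, not the real RAG API):

```python
def generate(input_ids, num_beams=1, max_length=50):
    """Stand-in for a generate() whose extra kwargs were dropped."""
    return input_ids[:max_length]


try:
    # clean_up_tokenization / print_docs are no longer in the signature,
    # so Python rejects the call before any generation happens.
    generate([0, 1, 2], print_docs=True, clean_up_tokenization=True)
except TypeError as err:
    print(err)  # generate() got an unexpected keyword argument 'print_docs'
```

Removing the stale kwargs from the call site, as the diff above does, is the whole fix.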
30 changes: 9 additions & 21 deletions examples/research_projects/rag/finetune_rag.py
@@ -443,7 +443,6 @@ def add_model_specific_args(parser, root_dir):
type=str,
help="RAG model type: sequence or token, if none specified, the type is inferred from the model_name_or_path",
)
-
return parser

@staticmethod
@@ -486,27 +485,10 @@ def add_retriever_specific_args(parser):
default=False,
help="Whether to use the dummy version of the dataset index. More info about custom indexes in the RagRetriever documentation as well as in `examples/rag/use_own_knowledge_dataset.py`",
)
-
-parser.add_argument(
-    "--num_retrieval_workers",
-    type=int,
-    default=1,
-    help="The number of retrieval actors to use when Ray is selected"
-    "for the distributed retriever. Has no effect when "
-    "distributed_retriever is set to pytorch.",
-)
return parser

@staticmethod
def add_ray_specific_args(parser):
-parser.add_argument(
-    "--num_retrieval_workers",
-    type=int,
-    default=1,
-    help="The number of retrieval actors to use when Ray is selected"
-    "for the distributed retriever. Has no effect when "
-    "distributed_retriever is set to pytorch.",
-)
-
# Ray cluster address.
parser.add_argument(
"--ray-address",
@@ -517,12 +499,18 @@ def add_ray_specific_args(parser):
"cluster. Has no effect if pytorch is used as the distributed "
"retriever.",
)

parser.add_argument(
"--num_retrieval_workers",
type=int,
default=1,
help="The number of retrieval actors to use when Ray is selected"
"for the distributed retriever. Has no effect when "
"distributed_retriever is set to pytorch.",
)
return parser


def main(args=None, model=None) -> GenerativeQAModule:
-
parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
parser = GenerativeQAModule.add_model_specific_args(parser, os.getcwd())
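For context on the argparse changes above: `main` chains these static methods onto a single parser (see `parser = GenerativeQAModule.add_model_specific_args(parser, os.getcwd())`), so each helper must return the parser it received, and no option string may be registered twice. A minimal sketch of both failure modes using only stdlib `argparse` (helper names shortened; not the actual module code):

```python
import argparse


def add_retriever_args(parser):
    """Stand-in for add_retriever_specific_args before the fix."""
    parser.add_argument("--num_retrieval_workers", type=int, default=1)
    return parser


def add_ray_args(parser):
    """Stand-in for add_ray_specific_args after the fix."""
    parser.add_argument("--num_retrieval_workers", type=int, default=1)
    # Without this return, `parser = add_ray_args(parser)` would bind None.
    return parser


parser = argparse.ArgumentParser()
parser = add_retriever_args(parser)
try:
    # Registering the same option string a second time raises ArgumentError,
    # which is why the duplicate definition had to be removed.
    parser = add_ray_args(parser)
except argparse.ArgumentError as err:
    print(err)  # argument --num_retrieval_workers: conflicting option string
```

Keeping `--num_retrieval_workers` only in `add_ray_specific_args`, and returning the parser there, resolves both problems at once.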
