Skip to content

Commit ac8a287

Browse files
committed
update README & clean code
1 parent f0dc676 commit ac8a287

File tree

10 files changed

+250
-23
lines changed

10 files changed

+250
-23
lines changed

.gitignore

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
# Data Files
2-
dataset/
2+
dataset/*
3+
!dataset/README.md
4+
!dataset/*/demos.json
35
saved_checkpoints/
46
wandb/
57
qa_results/

README.md

Lines changed: 94 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,94 @@
1-
# Coming soon...
1+
<h1 align="center">
2+
InstructRAG
3+
</h1>
4+
5+
<h3 align="center">
6+
Instructing Retrieval-Augmented Generation with Explicit Denoising <br>
7+
[<a href="https://arxiv.org/abs/2406.13629">arXiv</a>] [<a href="https://arxiv.org/abs/2406.13629">Website</a>] [<a href="https://huggingface.co/meng-lab/TriviaQA-InstructRAG-FT">Model</a>] [<a href="https://huggingface.co/datasets/meng-lab/InstructRAG">Dataset</a>] [<a href="https://x.com/weizhepei/status/1803992285899620837">X Summary</a>]
8+
</h3>
9+
10+
InstructRAG is a simple yet effective RAG framework that allows LMs to explicitly denoise retrieved contents by generating rationales for better verifiability and trustworthiness.
11+
12+
![](https://weizhepei.com/instruct-rag-page/static/images/instructrag.pdf)
13+
14+
## **InstructRAG Key Features:**
15+
16+
- 🤖 **Self-Synthesis**: InstructRAG leverages instruction-tuned LMs to generate their OWN supervision for denoising.
17+
- 🔌 **Easy-to-Use**: InstructRAG can be applied in both in-context learning (ICL) and supervised fine-tuning (SFT).
18+
- 🚀 **Effectiveness**: Up to 8.3% better results across 5 benchmarks (Table [5](https://arxiv.org/html/2406.13629v1#S3.T5)).
19+
- 💪 **Noise Robustness**: InstructRAG is robust to increased noise ratios in both training-free and trainable scenarios (Figure [3](https://arxiv.org/html/2406.13629v1#S3.F3)).
20+
- 🔁 **Task Transferability**: InstructRAG can solve out-of-domain unseen tasks (Figure [4](https://arxiv.org/html/2406.13629v1#S3.F4)).
21+
22+
Please see also our [paper](https://arxiv.org/abs/2406.13629) and [X summary](https://x.com/weizhepei/status/1803992285899620837) for more details.
23+
24+
## 🔗 Quick Links
25+
- [InstructRAG: Instructing Retrieval-Augmented Generation with Explicit Denoising](#instructrag-key-features)
26+
- [Installation](#installation)
27+
- [Training Script](#training-script)
28+
- [Evaluation](#evaluation)
29+
- [Generation Example](#generation-example)
30+
- [Model Checkpoints](#model-checkpoints)
31+
32+
## Installation
33+
The following script will create an Python virtual environment and install all required packages.
34+
```shell
35+
bash setup.sh
36+
```
37+
38+
Alternatively, you can also directly create a conda environment using the provided configuration file.
39+
40+
```shell
41+
conda env create -f environment.yml
42+
```
43+
44+
## Training Script
45+
To train the model (i.e., InstructRAG-FT), just activate the environment and run the following training script. The training config is set for 4xH100 80G GPUs. You may need to adjust NUM_DEVICE and PER_DEVICE_BATCH_SIZE based on your computation environment.
46+
47+
```shell
48+
conda activate instrag
49+
bash train.sh
50+
```
51+
## Evaluation
52+
There are two instantiations of our framework:
53+
- InstructRAG-ICL: training-free & easy-to-adapt
54+
- InstructRAG-FT: trainable & better performance
55+
56+
Use the following script to evaluate InstructRAG in both training-free and trainable settings. You can specify the task and model by adjusting DATASET and MODEL in `eval.sh`.
57+
58+
```shell
59+
conda activate instrag
60+
bash eval.sh
61+
```
62+
63+
64+
## Generation Example
65+
66+
The following case study shows that InstructRAG can effectively identify relevant information from noisy input and leverage its own knowledge to correctly answer questions when required. The red texts denote irrelevant or inaccurate model generations, while the green texts denote contents relevant to the question.
67+
68+
![](https://weizhepei.com/instruct-rag-page/static/images/case_study.pdf)
69+
70+
71+
## Model Checkpoints
72+
Below is the full list of InstructRAG models fine-tuned on each dataset in our work.
73+
74+
| Dataset | HF Model Repo | Retriever |
75+
|------------------------------|-----------------------------------------------------------------------------------------------------------|:------:|
76+
| PopQA | [meng-lab/PopQA-InstructRAG-FT](https://huggingface.co/meng-lab/PopQA-InstructRAG-FT) | Contriever |
77+
| TriviaQA | [meng-lab/TriviaQA-InstructRAG-FT](https://huggingface.co/meng-lab/TriviaQA-InstructRAG-FT) | Contriever |
78+
| Natural Questions | [meng-lab/NaturalQuestions-InstructRAG-FT](https://huggingface.co/meng-lab/NaturalQuestions-InstructRAG-FT) | DPR |
79+
| ASQA | [meng-lab/ASQA-InstructRAG-FT](https://huggingface.co/meng-lab/ASQA-InstructRAG-FT) | GTR |
80+
| 2WikiMultiHopQA | [meng-lab/2WikiMultiHopQA-InstructRAG-FT](https://huggingface.co/meng-lab/2WikiMultiHopQA-InstructRAG-FT) | BM25 |
81+
82+
## Bugs or Questions?
83+
If you have any questions related to the code or the paper, feel free to email Zhepei (zhepei.wei@virginia.edu). If you encounter any problems when using the code, or want to report a bug, feel free to open an issue! Please try to specify the problem with details so we can help you better and quicker!
84+
85+
## Citation
86+
Please cite our paper if you find the repo helpful in your work:
87+
88+
```bibtex
89+
@article{wei2024instructrag,
90+
title={{InstructRAG}: Instructing Retrieval-Augmented Generation with Explicit Denoising},
91+
author={Wei, Zhepei and Chen, Wei-Lin and Meng, Yu},
92+
year={2024}
93+
}
94+
```

dataset/README.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
## Dataset
2+
The datasets (augmented with retirevd documents) used in our work can be downdoaded from our HF dataset repo: [meng-lab/InstructRAG](https://huggingface.co/datasets/meng-lab/InstructRAG).
3+
4+
5+
Please refer to the [generate_rationale.sh](../generate_rationale.sh) script for detailed instructions on preparing data with your own corpus.

environment.yml

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
name: instrag
2+
channels:
3+
- defaults
4+
dependencies:
5+
- _libgcc_mutex=0.1=main
6+
- _openmp_mutex=5.1=1_gnu
7+
- bzip2=1.0.8=h5eee18b_6
8+
- ca-certificates=2024.3.11=h06a4308_0
9+
- ld_impl_linux-64=2.38=h1181459_1
10+
- libffi=3.4.4=h6a678d5_1
11+
- libgcc-ng=11.2.0=h1234567_1
12+
- libgomp=11.2.0=h1234567_1
13+
- libstdcxx-ng=11.2.0=h1234567_1
14+
- libuuid=1.41.5=h5eee18b_0
15+
- ncurses=6.4=h6a678d5_0
16+
- openssl=3.0.14=h5eee18b_0
17+
- python=3.10.14=h955ad1f_1
18+
- readline=8.2=h5eee18b_0
19+
- sqlite=3.45.3=h5eee18b_0
20+
- tk=8.6.14=h39e8969_0
21+
- tzdata=2024a=h04d1e81_0
22+
- xz=5.4.6=h5eee18b_1
23+
- zlib=1.2.13=h5eee18b_1
24+
- pip:
25+
- accelerate==0.31.0
26+
- aiosignal==1.3.1
27+
- annotated-types==0.7.0
28+
- anyio==4.4.0
29+
- attrs==23.2.0
30+
- certifi==2024.6.2
31+
- charset-normalizer==3.3.2
32+
- click==8.1.7
33+
- cloudpickle==3.0.0
34+
- cmake==3.29.6
35+
- diskcache==5.6.3
36+
- dnspython==2.6.1
37+
- einops==0.8.0
38+
- email-validator==2.2.0
39+
- exceptiongroup==1.2.1
40+
- fastapi==0.111.0
41+
- fastapi-cli==0.0.4
42+
- filelock==3.15.4
43+
- flash-attn==2.5.6
44+
- frozenlist==1.4.1
45+
- fsspec==2024.6.0
46+
- h11==0.14.0
47+
- httpcore==1.0.5
48+
- httptools==0.6.1
49+
- httpx==0.27.0
50+
- huggingface-hub==0.23.4
51+
- idna==3.7
52+
- interegular==0.3.3
53+
- jinja2==3.1.4
54+
- joblib==1.4.2
55+
- jsonschema==4.22.0
56+
- jsonschema-specifications==2023.12.1
57+
- lark==1.1.9
58+
- llvmlite==0.43.0
59+
- lm-format-enforcer==0.9.8
60+
- markdown-it-py==3.0.0
61+
- markupsafe==2.1.5
62+
- mdurl==0.1.2
63+
- mpmath==1.3.0
64+
- msgpack==1.0.8
65+
- nest-asyncio==1.6.0
66+
- networkx==3.3
67+
- ninja==1.11.1.1
68+
- numba==0.60.0
69+
- numpy==1.26.4
70+
- nvidia-cublas-cu12==12.1.3.1
71+
- nvidia-cuda-cupti-cu12==12.1.105
72+
- nvidia-cuda-nvrtc-cu12==12.1.105
73+
- nvidia-cuda-runtime-cu12==12.1.105
74+
- nvidia-cudnn-cu12==8.9.2.26
75+
- nvidia-cufft-cu12==11.0.2.54
76+
- nvidia-curand-cu12==10.3.2.106
77+
- nvidia-cusolver-cu12==11.4.5.107
78+
- nvidia-cusparse-cu12==12.1.0.106
79+
- nvidia-ml-py==12.555.43
80+
- nvidia-nccl-cu12==2.19.3
81+
- nvidia-nvjitlink-cu12==12.5.40
82+
- nvidia-nvtx-cu12==12.1.105
83+
- orjson==3.10.5
84+
- outlines==0.0.34
85+
- packaging==24.1
86+
- pip==24.0
87+
- prometheus-client==0.20.0
88+
- protobuf==5.27.1
89+
- psutil==6.0.0
90+
- py-cpuinfo==9.0.0
91+
- pydantic==2.7.4
92+
- pydantic-core==2.18.4
93+
- pygments==2.18.0
94+
- python-dotenv==1.0.1
95+
- python-multipart==0.0.9
96+
- pyyaml==6.0.1
97+
- ray==2.30.0
98+
- referencing==0.35.1
99+
- regex==2024.5.15
100+
- requests==2.32.3
101+
- rich==13.7.1
102+
- rpds-py==0.18.1
103+
- safetensors==0.4.3
104+
- scipy==1.13.1
105+
- sentencepiece==0.2.0
106+
- setuptools==69.5.1
107+
- shellingham==1.5.4
108+
- sniffio==1.3.1
109+
- starlette==0.37.2
110+
- sympy==1.12.1
111+
- tiktoken==0.6.0
112+
- tokenizers==0.19.1
113+
- torch==2.2.1
114+
- tqdm==4.66.4
115+
- transformers==4.41.2
116+
- triton==2.2.0
117+
- typer==0.12.3
118+
- typing-extensions==4.12.2
119+
- ujson==5.10.0
120+
- urllib3==2.2.2
121+
- uvicorn==0.30.1
122+
- uvloop==0.19.0
123+
- vllm==0.4.1
124+
- vllm-nccl-cu12==2.18.1.0.4.0
125+
- watchfiles==0.22.0
126+
- websockets==12.0
127+
- wheel==0.43.0
128+
- xformers==0.0.25

eval.sh

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,9 @@
1-
# conda activate instructrag
2-
3-
export DATASET=ASQA
4-
export CACHE_DIR=/p/llmresearch/huggingface/hub
5-
MODEL=InstructRAG-ICL # [InstructRAG-FT, InstructRAG-ICL]
1+
DATASET=PopQA
2+
MODEL=InstructRAG-FT # [InstructRAG-FT, InstructRAG-ICL]
63

74
CUDA_VISIBLE_DEVICES=0 python src/inference.py \
85
--dataset_name $DATASET \
96
--rag_model $MODEL \
107
--n_docs 5 \
118
--output_dir qa_results/${MODEL}/${DATASET}\
12-
--cache_dir $CACHE_DIR \
139
--load_local_model \

generate_rationale.sh

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,4 @@
1-
# conda activate instructrag
2-
3-
export DATASET=ASQA
4-
export CACHE_DIR=/p/llmresearch/huggingface/hub
1+
DATASET=PopQA
52

63
CUDA_VISIBLE_DEVICES=0 python src/inference.py \
74
--dataset_name $DATASET \

requirement.txt

Lines changed: 0 additions & 5 deletions
This file was deleted.

setup.sh

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
#!/bin/bash
2+
3+
# Create a new conda environment with Python 3.10
4+
conda create -n instrag python=3.10 -y
5+
6+
# Activate the new conda environment
7+
conda activate instrag
8+
9+
# Install numpy, vllm, and accelerate
10+
pip install numpy==1.26.4 vllm==0.4.1 accelerate
11+
12+
# Install flash-attn
13+
pip install flash-attn==2.5.6 --no-build-isolation

src/finetune.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,8 @@ class TrainingArguments(transformers.TrainingArguments):
9494
def main():
9595
parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
9696
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
97-
97+
training_args.fsdp_config=dict(fsdp_transformer_layer_cls_to_wrap=["LlamaDecoderLayer"])
98+
TrainingArguments.fsdp_config = training_args.fsdp_config
9899
ctx_mgr = common_utils.staggered_object_creation(
99100
local_rank=training_args.local_rank, world_size=training_args.world_size
100101
)
@@ -120,7 +121,7 @@ def main():
120121
truncation_side="left",
121122
use_fast=training_args.use_fast_tokenizer,
122123
)
123-
124+
124125
tokenizer.padding = training_args.padding
125126
if tokenizer.pad_token is None:
126127
tokenizer.pad_token = tokenizer.eos_token

train.sh

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,15 @@
1-
DATASET='ASQA'
1+
DATASET=PopQA
22
PER_DEVICE_BATCH_SIZE=1
33
NUM_DEVICE=4
44
TOTAL_BATCH_SIZE=128
55
GRADIENT_ACC_STEPS=$(($TOTAL_BATCH_SIZE/$NUM_DEVICE/$PER_DEVICE_BATCH_SIZE))
66

7-
export WANDB_MODE=offline
8-
97
CUDA_VISIBLE_DEVICES="0,1,2,3" torchrun --nproc_per_node=$NUM_DEVICE src/finetune.py \
108
--model_name_or_path meta-llama/Meta-Llama-3-8B-Instruct \
119
--dataset_name $DATASET \
1210
--output_dir saved_checkpoints/InstructRAG-FT/${DATASET} \
1311
--per_device_train_batch_size $PER_DEVICE_BATCH_SIZE \
1412
--gradient_accumulation_steps $GRADIENT_ACC_STEPS \
15-
--cache_dir "/p/llmresearch/huggingface/hub" \
1613
--num_train_epochs 2 \
1714
--n_docs 5 \
1815
--learning_rate 2.5e-5 \

0 commit comments

Comments
 (0)