To demonstrate the usage of RAG-FiT data augmentation, we follow the experiments presented in the paper, using the ASQA Q&A dataset and the Phi-3 model. We compare a baseline configuration with 4 other configurations:
1. Retrieval augmentation using a corpus, inserting the documents in the prompt after the question.
2. Similar to (1), but with the model fine-tuned on the completions.
3. Similar to (1), adding a Chain-of-Thought instruction for the model to explain its reasoning and format its answer.
4. Similar to (3), but with the model fine-tuned on the completions while implementing a technique from RAFT where distracting documents are used.
The ASQA dataset has two types of answers: a long answer and lists of short answers (actually a list of lists). Additionally, it contains only a minimal amount of context, so we augment it using a corpus stored as a vector DB; we use Qdrant.
In order to train configuration (4), we need well-reasoned CoT responses as labels, so we use the OpenAI GPT-4 model to augment the dataset with these synthetic labels.
Notice: all the configurations mentioned here, implementing the experiments done in the paper, are saved in
. They don't run by default; they need to be specified by running:
python -cp configs/paper -cn config-name-without-extension
The first step is to augment the entire dataset (train, dev) with relevant documents, based on the questions; see processing-asqa-retrieval.yaml. Let's focus on the different steps:
- _target_: ragfit.processing.dataset_loaders.loaders.HFLoader
  inputs: train
  path: din0s/asqa
  split: train
- _target_: ragfit.processing.dataset_loaders.loaders.HFLoader
  inputs: dev
  path: din0s/asqa
  split: dev
We load the train and dev splits to be used in the pipeline; later steps refer to them by the names given in the inputs keyword.
- _target_: ragfit.processing.local_steps.common_datasets.ASQA
  inputs: [train, dev]
We do some minimal processing specific to ASQA: renaming columns, collecting the short and long answers, and applying a consistent schema with keys such as query, answers, positive_passages, etc. Feel free to add your own types of pre-processing.
Notice that the inputs keyword can accept a list of strings, meaning the step will run over each of the datasets specified.
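To illustrate this kind of normalization, here is a minimal sketch, not the ASQA step itself. It assumes a raw record with the fields ambiguous_question, qa_pairs (each holding short_answers) and annotations (each holding long_answer); the function name normalize_asqa and the short_answers output key are hypothetical.

```python
def normalize_asqa(record: dict) -> dict:
    """Map a raw ASQA-style record onto a flat, consistent schema (illustrative only)."""
    # Long answers: one per annotation.
    long_answers = [ann["long_answer"] for ann in record.get("annotations", [])]
    # Short answers: one inner list per disambiguated question, i.e. a list of lists.
    short_answers = [list(pair["short_answers"]) for pair in record.get("qa_pairs", [])]
    return {
        "query": record["ambiguous_question"],  # column renaming
        "answers": long_answers,
        "short_answers": short_answers,  # hypothetical key name
    }


# Made-up record, used only to show the shape of the data.
raw = {
    "ambiguous_question": "Who sang the song?",
    "qa_pairs": [
        {"question": "Who sang the original?", "short_answers": ["A", "Alpha"]},
        {"question": "Who sang the cover?", "short_answers": ["B"]},
    ],
    "annotations": [{"long_answer": "A sang the original; B sang the cover."}],
}
example = normalize_asqa(raw)
```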
- _target_:
  inputs: [train, dev]
  pipeline_or_yaml_path: ./configs/external/haystack/qdrant.yaml
  docs_key: positive_passages
  query_key: query
This is the retrieval step. We use the Haystack framework for building RAG pipelines; in this example, the Haystack pipeline consists of an embedder and a retriever, connecting to Qdrant through a Qdrant-Haystack integration (all defined in the requirements file). The Haystack pipeline is initialized from the qdrant.yaml configuration. One can use other frameworks for retrieval, like LangChain, LlamaIndex, or others.
The retrieval step stores the most relevant documents (k=5) under the key given by docs_key, and the query is read from the key given by query_key.
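To make the data flow of a retrieval step concrete, here is a toy sketch that does not use Haystack or Qdrant. It swaps in a bag-of-words cosine similarity for the embedder and an in-memory list for the vector DB; the functions embed, cosine and retrieve and the sample corpus are all hypothetical.

```python
import math
from collections import Counter


def embed(text: str) -> Counter:
    """Toy stand-in for an embedder: a bag-of-words count vector."""
    return Counter(text.lower().split())


def cosine(a: Counter, b: Counter) -> float:
    dot = sum(a[t] * b[t] for t in a)
    norm = math.sqrt(sum(v * v for v in a.values())) * math.sqrt(sum(v * v for v in b.values()))
    return dot / norm if norm else 0.0


def retrieve(example, corpus, query_key="query", docs_key="positive_passages", k=5):
    """Store the k documents most similar to example[query_key] under docs_key."""
    query_vec = embed(example[query_key])
    ranked = sorted(corpus, key=lambda doc: cosine(query_vec, embed(doc["content"])), reverse=True)
    return {**example, docs_key: ranked[:k]}


corpus = [
    {"title": "Cats", "content": "cats sleep a lot"},
    {"title": "Berlin", "content": "berlin is the capital of germany"},
    {"title": "Paris", "content": "paris is the capital of france"},
]
retrieved = retrieve({"query": "capital of france"}, corpus, k=2)
```

Only the contract matters here: the step reads the query from one key and writes a ranked list of documents to another.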
- _target_: ragfit.processing.local_steps.context.ContextHandler
  inputs: [train, dev]
  docs_key: positive_passages
In this simple step, the retrieved documents are processed; they have title and content fields, and this step combines these into a single string for every document. This step may be unnecessary, depending on the retrieval mechanism and format.
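A minimal sketch of this kind of merging follows. The "title: content" output format and the function name merge_title_and_content are assumptions for illustration, not necessarily what ContextHandler produces.

```python
def merge_title_and_content(example, docs_key="positive_passages"):
    """Replace each {title, content} document with a single string."""
    merged = [f"{doc['title']}: {doc['content']}" for doc in example[docs_key]]
    return {**example, docs_key: merged}


handled = merge_title_and_content(
    {
        "query": "capital of france",
        "positive_passages": [{"title": "Paris", "content": "Paris is the capital of France."}],
    }
)
```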
- _target_: ragfit.processing.global_steps.sampling.Sampler
  inputs: [train, dev]
  k: 1
  input_key: positive_passages
  output_key: negative_passages
The Sampler class deals with sampling examples from the same dataset or from others. In order to train the RAFT-based model on a combination of relevant and distracting documents, we need to collect these distracting documents. Here we chose to collect positive documents from other examples, to be used as negative documents. The Sampler is run with k=1; it collects only the positive_passages from the examples it samples and stores them under a new key, negative_passages.
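The idea can be sketched as follows. This is an illustration under assumptions, not the Sampler implementation: the function name sample_negatives, the fixed seed, and the choice to exclude the current example from sampling are ours.

```python
import random


def sample_negatives(dataset, k=1, input_key="positive_passages",
                     output_key="negative_passages", seed=0):
    """For each example, borrow the positive documents of k other examples as negatives."""
    rng = random.Random(seed)
    augmented = []
    for i, example in enumerate(dataset):
        # Assumption: never sample the example itself.
        other_indices = [j for j in range(len(dataset)) if j != i]
        picked = rng.sample(other_indices, k)
        negatives = [doc for j in picked for doc in dataset[j][input_key]]
        augmented.append({**example, output_key: negatives})
    return augmented


dataset = [
    {"query": "q0", "positive_passages": ["a0", "a1"]},
    {"query": "q1", "positive_passages": ["b0"]},
    {"query": "q2", "positive_passages": ["c0"]},
]
augmented = sample_negatives(dataset, k=1)
```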
- _target_: ragfit.processing.global_steps.output.OutputData
  inputs: [train, dev]
  prefix: asqa
Finally, we write the two resulting datasets to disk. They represent the retrieval-augmented datasets, ready to be processed for the different tasks.
To run this process:
python -cp configs/paper -cn processing-asqa-retrieval
For the baseline, there is no context; only the question is presented to the model. We use instruction-following models that have a chat template built in. The framework populates the chat template using the inputs and outputs we generate, so we don't need to worry about roles and special tokens. Additionally, the system instruction is specified only during training and inference; it needn't be part of the dataset, so the next steps mainly deal with prompt generation.
These are the interesting steps:
- _target_: ragfit.processing.dataset_loaders.loaders.LocalLoader
  inputs: dev
  filename: asqa-dev.jsonl
- _target_: ragfit.processing.local_steps.prompter.TextPrompter
  inputs: dev
  prompt_file: ragfit/processing/prompts/qa-short.txt
  output_key: prompt
  query: query
We load the retrieval-augmented files we generated locally in the previous section.
The TextPrompter populates a template file containing placeholders in Python format; see the short template. The step replaces the placeholders with variables using a provided mapping. The result is a string, saved under the key given by output_key.
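Template population with Python-format placeholders can be sketched as below. The template text is made up for illustration (it is not the content of qa-short.txt), and build_prompt is a hypothetical helper rather than the TextPrompter code.

```python
def build_prompt(example, template, mapping, output_key="prompt"):
    """Fill the template placeholders from example fields.

    mapping maps a placeholder name in the template to a key in the example.
    """
    values = {placeholder: example[key] for placeholder, key in mapping.items()}
    return {**example, output_key: template.format(**values)}


# Hypothetical template; the real file may look different.
template = "Question: {query}\nAnswer:"
prompted = build_prompt({"query": "Who wrote Hamlet?"}, template, mapping={"query": "query"})
```

Keeping the mapping separate from the template lets the same template file be reused with datasets whose column names differ.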
To run this process:
python -cp configs/paper -cn processing-asqa-baseline
Preparing for configurations (1) and (2), we want to augment the examples with the top 5 documents we collected in the first step.
- _target_: ragfit.processing.local_steps.context.DocumentsJoiner
  inputs: [train, dev]
  docs_key: positive_passages
  k: 5
- _target_: ragfit.processing.local_steps.prompter.TextPrompter
  inputs: [train, dev]
  prompt_file: ragfit/processing/prompts/qa.txt
  output_key: prompt
  question: query
  context: positive_passages
The DocumentsJoiner joins a list of strings and is needed before the TextPrompter we saw in the previous section. We prepare a dev file (for testing the model with retrieved documents) and also a training file, in order to run fine-tuning. Both configurations will be evaluated on the dev dataset.
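A minimal sketch of the joining step, assuming the documents are already plain strings. The newline separator and the function name join_documents are assumptions; the actual DocumentsJoiner may format its output differently.

```python
def join_documents(example, docs_key="positive_passages", k=None, separator="\n"):
    """Join the first k documents (all of them when k is None) into one string."""
    docs = example[docs_key]
    selected = docs if k is None else docs[:k]
    return {**example, docs_key: separator.join(selected)}


joined_top2 = join_documents({"positive_passages": ["d1", "d2", "d3"]}, k=2)
joined_all = join_documents({"positive_passages": ["d1", "d2", "d3"]})
```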
To run this process:
python -cp configs/paper -cn processing-asqa-context
We prepare a dev set with a CoT reasoning prompt. The configuration is similar to the Context configuration, but here we use a different prompt template:
- _target_: ragfit.processing.local_steps.prompter.TextPrompter
  inputs: dev
  prompt_file: ragfit/processing/prompts/cot.txt
  output_key: prompt
  question: query
  context: positive_passages
To run this process:
python -cp configs/paper -cn processing-asqa-cot-dev
In order to train a model on a CoT-based prompt, we need to collect well-reasoned responses; we use GPT-4 for that. Additionally, we implement a technique from RAFT where some percentage of the examples have purely distractor documents, in order to train the model's ability to filter noise. Here are the relevant steps:
- _target_: ragfit.processing.local_steps.raft.RAFTStep
  inputs: train
  k: 5
  raft_p: 0.5
  neg_docs_num: 2
  output_key: raft_docs
The RAFTStep implements the logic presented in the paper; the percentage of examples with purely distractor documents is defined by raft_p. The list of documents, some relevant and some distracting, is saved under the key given by output_key (here raft_docs).
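The selection logic can be sketched roughly as follows. This is an illustration of the idea as described here, not the RAFTStep code: the exact composition of the mixed case (top-k positives plus neg_docs_num negatives, shuffled) and the key names positive_key and negative_key are assumptions.

```python
import random


def raft_documents(example, rng, k=5, raft_p=0.5, neg_docs_num=2,
                   positive_key="positive_passages",
                   negative_key="negative_passages",
                   output_key="raft_docs"):
    """With probability raft_p keep only distractors; otherwise mix positives and distractors."""
    positives = example[positive_key]
    negatives = example[negative_key]
    if rng.random() < raft_p:
        # Purely distracting documents: the model has to cope with pure noise.
        docs = list(negatives[:k])
    else:
        # Assumed mix: top-k relevant documents plus a few distractors.
        docs = list(positives[:k]) + list(negatives[:neg_docs_num])
        rng.shuffle(docs)
    return {**example, output_key: docs}


example = {
    "positive_passages": ["p1", "p2", "p3", "p4", "p5"],
    "negative_passages": ["n1", "n2", "n3"],
}
rng = random.Random(0)
only_noise = raft_documents(example, rng, raft_p=1.0)
mixed = raft_documents(example, rng, raft_p=0.0)
```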
- _target_: ragfit.processing.local_steps.context.DocumentsJoiner
  inputs: train
  docs_key: raft_docs
- _target_: ragfit.processing.local_steps.prompter.TextPrompter
  inputs: train
  prompt_file: ragfit/processing/prompts/cot.txt
  output_key: prompt
  question: query
  context: raft_docs
The documents are joined into strings; since k is not set here, all documents are used. The prompt used is the same as when building the dev dataset.
Next comes the interaction with OpenAI; we implemented an OpenAI class using Azure, and one can implement it using other abstractions. The step itself needs the prompt_key and an instruction file, and the results are saved under the answer_key.
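The shape of such a labeling step can be sketched without committing to a specific client. Here chat_fn is a hypothetical callable standing in for whichever API wrapper is used (it takes a list of chat messages and returns the reply text), and label_with_model and the instruction text are illustrative, not part of the framework.

```python
def label_with_model(example, chat_fn, instruction,
                     prompt_key="prompt", answer_key="generated_answer"):
    """Send the system instruction plus the example's prompt; store the reply."""
    messages = [
        {"role": "system", "content": instruction},
        {"role": "user", "content": example[prompt_key]},
    ]
    return {**example, answer_key: chat_fn(messages)}


def stub_chat(messages):
    """Stand-in for a real model call, so the sketch runs offline."""
    return f"stub answer to: {messages[-1]['content']}"


labeled = label_with_model(
    {"prompt": "Question: Who wrote Hamlet?"},
    chat_fn=stub_chat,
    instruction="Answer the question and explain your reasoning.",
)
```

Injecting the callable keeps the dataset logic independent of the provider, which is what makes swapping Azure for another abstraction straightforward.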
- _target_: ragfit.processing.local_steps.api.openai.OpenAIChat
  inputs: train
  prompt_key: prompt
  answer_key: generated_answer
  instruction: ragfit/processing/prompts/prompt_instructions/qa.txt
  api_version: 2024-05-01-preview
  model: GPT-4-32k-Bot
To run this process:
python -cp configs/paper -cn processing-asqa-cot-train