Merge pull request #33 from arcee-ai/feat/oss-cli-sdk-support
OSS work: sdk and cli support for training retriever, e2e, and qa-gen, license
Ben-Epstein authored Sep 20, 2023
2 parents 7614a21 + 98c72f0 commit 01d4ff8
Showing 20 changed files with 860 additions and 471 deletions.
1 change: 1 addition & 0 deletions .github/CODEOWNERS
@@ -0,0 +1 @@
@shamanez @Jacobsolawetz @ben-epstein @SachiraKuruppu @metric-space
201 changes: 0 additions & 201 deletions .requirements.txt

This file was deleted.

25 changes: 25 additions & 0 deletions CONTRIBUTING.md
@@ -0,0 +1,25 @@
# Contributing to DALM

Thanks for helping out! We're excited to see your issues and PRs.

## Building locally

Building the repo is straightforward: clone the repo and install the package. We use [invoke](https://github.com/pyinvoke/invoke) to manage `DALM` development tasks.
```shell
git clone https://github.com/arcee-ai/DALM.git && cd DALM
pip install invoke
inv install
```
This will install the repo, with its dev dependencies, in [editable mode](https://setuptools.pypa.io/en/latest/userguide/development_mode.html) (for live updates on code changes)

## Format, lint, test
Because we use `invoke`, the following is all you need to prepare for a PR:
```shell
inv format # black, ruff
inv lint # black check, ruff check, mypy
inv test # pytest
```

We require 95% test coverage for all PRs.

For more information on our `invoke` commands, see [`tasks.py`](https://github.com/arcee-ai/DALM/blob/main/tasks.py) and our [`pyproject.toml`](https://github.com/arcee-ai/DALM/blob/main/pyproject.toml) configuration.
File renamed without changes.
77 changes: 59 additions & 18 deletions README.md
@@ -26,43 +26,66 @@
For the first time in the literature, we modified the initial RAG-end2end model

- Additionally, we have data processing codes and synthetic data generation code inside the [datasets](https://github.com/arcee-ai/DALM/tree/main/dalm/datasets) folder.

# Usage
To perform training and evaluation for both the retriever model and the new rag-e2e model, please adhere to the following steps.

## Installation

You can install this repo directly via `pip install indomain`

Alternatively, for development or research, you can clone and install the repo locally:
```shell
git clone https://github.com/arcee-ai/DALM.git && cd DALM
pip install --upgrade -e .
```
This will install the DALM repo and all necessary dependencies.

Make sure things are installed correctly by running `dalm version`

## Data setup
### tl;dr
You can run `dalm qa-gen <path-to-dataset>` to preprocess your dataset for training. See `dalm qa-gen --help` for more options

If you do not have a dataset, you can start with ours:
```shell
dalm qa-gen dalm/datasets/toy_data_train.csv
```
- The setup for training and evaluation can be effortlessly executed provided you possess a [CSV](https://github.com/arcee-ai/DALM/tree/main/dalm/datasets/toy_data_train.csv) file containing two or three columns: `Passage`, `Query`, and (if running e2e) `Answer`. You can utilize the script [question_answer_generation.py](https://github.com/arcee-ai/DALM/blob/main/dalm/datasets/qa_gen/question_answer_generation.py) to generate this CSV.
- It's important to highlight that the retriever-only training method employs solely the passages and queries, whereas the rag-e2e training code utilizes all three columns.
- In our experiments, we utilize `BAAI/bge-large-en` as the default retriever and employ `meta-llama/Llama-2-7b-hf` as the default generator. The code is designed to be compatible with any embedding model or autoregressive model available in the Hugging Face model repository at https://huggingface.co/models.
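As a quick sanity check before training, you can verify that your CSV has the expected columns. This is a hypothetical helper using only the standard library, not part of the repo; the column names `Passage`, `Query`, and `Answer` follow the convention described above:

```python
import csv
import io

# Expected columns: Passage and Query always; Answer only for rag-e2e training
REQUIRED = {"Passage", "Query"}
E2E_REQUIRED = REQUIRED | {"Answer"}

def check_columns(csv_text: str, e2e: bool = False) -> bool:
    """Return True if the CSV header contains the columns DALM expects."""
    reader = csv.DictReader(io.StringIO(csv_text))
    cols = set(reader.fieldnames or [])
    needed = E2E_REQUIRED if e2e else REQUIRED
    return needed <= cols

sample = "Passage,Query,Answer\nSome passage,Some question,Some answer\n"
print(check_columns(sample))            # retriever-only: needs Passage, Query
print(check_columns(sample, e2e=True))  # rag-e2e: also needs Answer
```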

## Training

You can leverage our scripts directly if you'd like, or you can use the `dalm` CLI. The arguments for both are identical.

### Train Retriever Only

Train `BAAI/bge-large-en` retriever with contrastive learning.

```shell
python dalm/training/retriever_only/train_retriever_only.py \
--train_dataset_csv_path "./dalm/datasets/toy_data_train.csv" \
--model_name_or_path "BAAI/bge-large-en" \
--output_dir "./dalm/training/rag_e2e/retriever_only_checkpoints" \
--use_peft \
--with_tracking \
--report_to all \
--per_device_train_batch_size 150
```
or
```shell
dalm train-retriever-only "BAAI/bge-large-en" "./dalm/datasets/toy_data_train.csv" \
--output-dir "./dalm/training/rag_e2e/retriever_only_checkpoints" \
--use-peft \
--with-tracking \
--report-to all \
--per-device-train-batch-size 150
```
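Contrastive retriever training of this kind typically scores each query against every passage in the batch, treating the matching passage as the positive and the rest as negatives. The following is a simplified, dependency-free sketch of such an in-batch contrastive loss, not the repo's actual implementation:

```python
import math

def in_batch_contrastive_loss(q_embs, p_embs, temperature=0.05):
    """Mean cross-entropy over queries, where query i's positive is passage i
    and every other passage in the batch serves as a negative."""
    def dot(a, b):
        return sum(x * y for x, y in zip(a, b))

    losses = []
    for i, q in enumerate(q_embs):
        scores = [dot(q, p) / temperature for p in p_embs]
        log_denom = math.log(sum(math.exp(s) for s in scores))
        losses.append(log_denom - scores[i])  # -log softmax of the positive
    return sum(losses) / len(losses)

# Toy embeddings: each query is closest to its own passage, so the loss is tiny
queries = [[1.0, 0.0], [0.0, 1.0]]
passages = [[0.9, 0.1], [0.1, 0.9]]
print(in_batch_contrastive_loss(queries, passages))
```

Swapping the passages (so each query's positive is the worst match) makes the loss large, which is the signal that drives the retriever's embeddings toward their paired passages.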

For all available arguments and options, see `dalm train-retriever-only --help`

### Train Retriever and Generator Jointly (RAG-e2e)
Train `Llama-2-7b` generator jointly with the retriever model `BAAI/bge-large-en`.

```shell
python dalm/training/rag_e2e/train_rage2e.py \
--dataset_path "./dalm/datasets/toy_data_train.csv" \
--retriever_name_or_path "BAAI/bge-large-en" \
@@ -72,6 +95,20 @@
--report_to all \
--per_device_train_batch_size 24
```
or
```shell
dalm train-rag-e2e \
"./dalm/datasets/toy_data_train.csv" \
"BAAI/bge-large-en" \
"meta-llama/Llama-2-7b-hf" \
--output-dir "./dalm/training/rag_e2e/rag_e2e_checkpoints" \
--with-tracking \
--report-to all \
--per-device-train-batch-size 24
```

For all available arguments and options, see `dalm train-rag-e2e --help`

## Evaluation

Here's a summary of evaluation results on a 200K-line test CSV of patent abstracts.
@@ -86,11 +123,15 @@
To run retriever only eval
(make sure you have the checkpoints in the project root)

```bash
python dalm/eval/eval_retriever_only.py --dataset_path qa_pairs_test.csv --retriever_model_name_or_path "BAAI/bge-large-en" --passage_column_name Abstract --query_column_name Question --retriever_peft_model_path retriever_only_checkpoints
```

For the e2e eval

```bash
python dalm/eval/eval_rag.py --dataset_path qa_pairs_test_2.csv --retriever_model_name_or_path "BAAI/bge-large-en" --generator_model_name_or_path "meta-llama/Llama-2-7b-hf" --passage_column_name Abstract --query_column_name Question --answer_column_name Answer --evaluate_generator --query_batch_size 5 --retriever_peft_model_path retriever_only_checkpoints --generator_peft_model_path generator_only_checkpoints
```
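Retriever evaluation of this kind generally boils down to ranking passages by embedding similarity and checking whether each query's correct passage lands in the top-k results. Here is a hypothetical sketch of such a hit-rate metric; the repo's eval scripts compute their own metrics:

```python
def top_k_hit_rate(similarities, correct_indices, k=3):
    """similarities[i] holds query i's score against every passage;
    correct_indices[i] is the index of query i's true passage."""
    hits = 0
    for scores, correct in zip(similarities, correct_indices):
        # indices of the k highest-scoring passages
        top_k = sorted(range(len(scores)), key=lambda j: scores[j], reverse=True)[:k]
        hits += correct in top_k
    return hits / len(similarities)

# Two queries over four passages; query 0's true passage ranks 1st, query 1's ranks 4th
sims = [[0.9, 0.2, 0.1, 0.3],
        [0.8, 0.7, 0.6, 0.1]]
print(top_k_hit_rate(sims, [0, 3], k=3))  # → 0.5
```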


## Contributing
See [CONTRIBUTING](https://github.com/arcee-ai/DALM/tree/main/CONTRIBUTING.md)
2 changes: 1 addition & 1 deletion dalm/__init__.py
@@ -1 +1 @@
-__version__ = "0.0.0"
+__version__ = "0.0.1"