Merge pull request #55 from arcee-ai/feat/eval-cli
refactor eval code, create eval cli
Ben-Epstein authored Sep 24, 2023
2 parents 3b45faa + 968a2d4 commit 2245860
Showing 11 changed files with 558 additions and 384 deletions.
43 changes: 41 additions & 2 deletions README.md
@@ -47,6 +47,8 @@ Make sure things are installed correctly by running `dalm version`. On a non-i
You can run `dalm qa-gen <path-to-dataset>` to preprocess your dataset for training. See `dalm qa-gen --help` for more options.
<br>If you do not have a dataset, you can start with ours
```shell
# Note - our dataset already has queries and answers, so you don't actually need to run this.
# replace `toy_data_train.csv` with your dataset of titles and passages
dalm qa-gen dalm/datasets/toy_data_train.csv
```
- Setup for training and evaluation requires only a [CSV](https://github.com/arcee-ai/DALM/tree/main/dalm/datasets/toy_data_train.csv) file with two or three columns: `Passage`, `Query`, and (if running e2e) `Answer`. You can use the script [question_answer_generation.py](https://github.com/arcee-ai/DALM/blob/main/dalm/datasets/qa_gen/question_answer_generation.py) to generate this CSV, or assemble one yourself as sketched below.
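
For illustration, a minimal CSV with the expected schema could be assembled like this (a sketch only: the rows and the `my_dataset.csv` name are invented, and only the column names matter):

```python
# Hypothetical example: build a tiny dataset with the expected columns.
# The text values are placeholders, not real training data.
import pandas as pd

df = pd.DataFrame(
    {
        "Passage": ["Aspirin inhibits COX enzymes, reducing prostaglandin synthesis."],
        "Query": ["How does aspirin reduce inflammation?"],
        "Answer": ["By inhibiting COX enzymes."],  # only needed for e2e training
    }
)
df.to_csv("my_dataset.csv", index=False)
```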
@@ -123,14 +125,51 @@ To run retriever only eval
(make sure you have the checkpoints in the project root)

```bash
python dalm/eval/eval_retriever_only.py --dataset_path qa_pairs_test.csv --retriever_name_or_path "BAAI/bge-large-en" --passage_column_name Abstract --query_column_name Question --retriever_peft_model_path retriever_only_checkpoints
python dalm/eval/eval_retriever_only.py \
--dataset_path qa_pairs_test.csv \
--retriever_name_or_path "BAAI/bge-large-en" \
--passage_column_name Abstract \
--query_column_name Question \
--retriever_peft_model_path retriever_only_checkpoints
```
or
```bash
dalm eval-retriever qa_pairs_test.csv \
--retriever-name-or-path "BAAI/bge-large-en" \
--passage-column-name Abstract \
--query-column-name Question \
--retriever-peft-model-path retriever_only_checkpoints
```
See `dalm eval-retriever --help` for all available arguments.
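
The same evaluation can also be driven from Python. A minimal sketch, using the example checkpoint paths above (keyword names and defaults mirror the `eval-retriever` wrapper added in `dalm/cli.py` below):

```python
# Sketch: programmatic equivalent of the `dalm eval-retriever` command above.
# Keyword names follow the evaluate_retriever call in dalm/cli.py.
from dalm.eval.eval_retriever_only import evaluate_retriever

evaluate_retriever(
    dataset_or_path="qa_pairs_test.csv",
    retriever_name_or_path="BAAI/bge-large-en",
    retriever_peft_model_path="retriever_only_checkpoints",
    passage_column_name="Abstract",
    query_column_name="Question",
    embed_dim=1024,         # CLI default
    max_length=128,         # CLI default; longer passages are truncated
    test_batch_size=8,
    device="cuda",
    torch_dtype="float16",  # or "bfloat16"
    top_k=10,
)
```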

For the e2e eval

```bash
python dalm/eval/eval_rag.py --dataset_path qa_pairs_test_2.csv --retriever_name_or_path "BAAI/bge-large-en" --generator_model_name_or_path "meta-llama/Llama-2-7b-hf" --passage_column_name Abstract --query_column_name Question --answer_column_name Answer --evaluate_generator --query_batch_size 5 --retriever_peft_model_path rag_e2e_checkpoints/retriever --generator_peft_model_path rag_e2e_checkpoints/generator
python dalm/eval/eval_rag.py \
--dataset_path qa_pairs_test_2.csv \
--retriever_name_or_path "BAAI/bge-large-en" \
--generator_name_or_path "meta-llama/Llama-2-7b-hf" \
--passage_column_name Abstract \
--query_column_name Question \
--answer_column_name Answer \
--evaluate_generator \
--query_batch_size 5 \
--retriever_peft_model_path rag_e2e_checkpoints/retriever \
--generator_peft_model_path rag_e2e_checkpoints/generator
```
or
```bash
dalm eval-rag qa_pairs_test.csv \
--retriever-name-or-path "BAAI/bge-large-en" \
--generator-name-or-path "meta-llama/Llama-2-7b-hf" \
--retriever-peft-model-path rag_e2e_checkpoints/retriever \
--generator-peft-model-path rag_e2e_checkpoints/generator \
--passage-column-name Abstract \
--query-column-name Question \
--answer-column-name Answer \
--query-batch-size 5
```
See `dalm eval-rag --help` for all available arguments.
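
Likewise for the end-to-end eval, a hedged Python sketch (keyword names mirror the `eval-rag` wrapper in `dalm/cli.py` below):

```python
# Sketch: programmatic equivalent of the `dalm eval-rag` command above.
from dalm.eval.eval_rag import evaluate_rag

evaluate_rag(
    dataset_or_path="qa_pairs_test.csv",
    retriever_name_or_path="BAAI/bge-large-en",
    generator_name_or_path="meta-llama/Llama-2-7b-hf",
    retriever_peft_model_path="rag_e2e_checkpoints/retriever",
    generator_peft_model_path="rag_e2e_checkpoints/generator",
    passage_column_name="Abstract",
    query_column_name="Question",
    answer_column_name="Answer",
    embed_dim=1024,           # CLI default
    max_length=128,
    test_batch_size=8,
    query_batch_size=5,
    device="cuda",
    torch_dtype="float16",    # or "bfloat16"
    top_k=10,
    evaluate_generator=True,  # False makes this equivalent to eval-retriever
)
```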


## Contributing
2 changes: 1 addition & 1 deletion dalm/__init__.py
@@ -1 +1 @@
__version__ = "0.0.4"
__version__ = "0.0.5"
114 changes: 110 additions & 4 deletions dalm/cli.py
@@ -8,6 +8,8 @@

from dalm import __version__
from dalm.datasets.qa_gen.question_answer_generation import generate_qa_from_disk
from dalm.eval.eval_rag import evaluate_rag
from dalm.eval.eval_retriever_only import evaluate_retriever
from dalm.training.rag_e2e.train_rage2e import train_e2e
from dalm.training.retriever_only.train_retriever_only import train_retriever

@@ -24,6 +26,11 @@ class DALMSchedulerType(Enum):
CONSTANT_WITH_WARMUP = SchedulerType.CONSTANT_WITH_WARMUP


class TorchDtype(str, Enum):
float16 = "float16"
bfloat16 = "bfloat16"


@cli.command()
def version() -> None:
"""Print the current version of DALM"""
@@ -35,7 +42,7 @@ def train_rag_e2e(
dataset_path: Annotated[
str,
typer.Argument(
help="Path to the dataset to train with. Can be a huggingface dataset directory or a csv file.",
help="Path to the dataset to train with. Can be an hf dataset dir, csv file, or path to hub file.",
show_default=False,
),
],
@@ -122,7 +129,7 @@ def train_rag_e2e(
] = True,
use_peft: Annotated[bool, typer.Option(help="Whether to use Peft during fine-tuning.")] = True,
) -> None:
"""End-to-end train an in-domain model, including the retreiver and generator"""
"""End-to-end train an in-domain model, including the retriever and generator"""
train_e2e(
dataset_or_path=dataset_path,
retriever_name_or_path=retriever_name_or_path,
@@ -163,7 +170,7 @@ def train_retriever_only(
dataset_path: Annotated[
str,
typer.Argument(
help="Path to the train dataset to train with. Can be a huggingface dataset directory or a csv file.",
help="Path to the train dataset to train with. Can be an hf dataset dir, csv file, or path to hub file.",
show_default=False,
),
],
@@ -231,7 +238,7 @@ def train_retriever_only(
] = True,
use_peft: Annotated[bool, typer.Option(help="Whether to use Peft during fine-tuning.")] = True,
) -> None:
"""End-to-end train an in-domain model, including the retriever and generator"""
"""Train only the retriever using contrastive training"""
train_retriever(
dataset_or_path=dataset_path,
retriever_name_or_path=retriever_name_or_path,
@@ -293,5 +300,104 @@ def qa_gen(
)


@cli.command()
def eval_rag(
dataset_path: Annotated[
str,
typer.Argument(
help="Path to the dataset to eval with. Can be an hf dataset dir, csv file, or path to hub file.",
show_default=False,
),
],
retriever_name_or_path: Annotated[
str, typer.Option(help="Path to pretrained retriever or identifier from huggingface.co/models.")
],
generator_name_or_path: Annotated[
str, typer.Option(help="Path to pretrained (causal) generator or identifier from huggingface.co/models.")
],
retriever_peft_model_path: Annotated[str, typer.Option(help="Path to the fine-tuned retriever peft layers")],
generator_peft_model_path: Annotated[str, typer.Option(help="Path to the fine-tuned generator peft layers")],
passage_column_name: Annotated[str, typer.Option(help="Name of the column containing the passage")] = "Abstract",
query_column_name: Annotated[str, typer.Option(help="Name of the column containing the query")] = "Question",
answer_column_name: Annotated[str, typer.Option(help="Name of the column containing the Answer")] = "Answer",
embed_dim: Annotated[int, typer.Option(help="Dimension of the model embedding")] = 1024,
max_length: Annotated[
int, typer.Option(help="The max passage sequence length during tokenization. Longer sequences are truncated")
] = 128,
test_batch_size: Annotated[int, typer.Option(help="Batch size (per device) for the test dataloader.")] = 8,
query_batch_size: Annotated[int, typer.Option(help="Batch size (per device) for generator input")] = 16,
device: Annotated[str, typer.Option(help="Device. cpu or cuda.")] = "cuda",
torch_dtype: Annotated[
TorchDtype, typer.Option(help="torch.dtype to use for tensors. float16 or bfloat16.")
] = TorchDtype.float16,
top_k: Annotated[int, typer.Option(help="Top K retrieval")] = 10,
evaluate_generator: Annotated[
bool, typer.Option(help="Enable generator evaluation. If false, equivalent to eval-retriever")
] = True,
) -> None:
"""Evaluate your end-to-end rag generator and retriever"""
evaluate_rag(
dataset_or_path=dataset_path,
retriever_name_or_path=retriever_name_or_path,
generator_name_or_path=generator_name_or_path,
retriever_peft_model_path=retriever_peft_model_path,
generator_peft_model_path=generator_peft_model_path,
passage_column_name=passage_column_name,
query_column_name=query_column_name,
answer_column_name=answer_column_name,
embed_dim=embed_dim,
max_length=max_length,
test_batch_size=test_batch_size,
query_batch_size=query_batch_size,
device=device,
torch_dtype=torch_dtype.value,
# torch_dtype=cast(Literal["float16", "bfloat16"], torch_dtype.value),
top_k=top_k,
evaluate_generator=evaluate_generator,
)


@cli.command()
def eval_retriever(
dataset_path: Annotated[
str,
typer.Argument(
help="Path to the dataset to eval with. Can be an hf dataset dir, csv file, or path to hub file.",
show_default=False,
),
],
retriever_name_or_path: Annotated[
str, typer.Option(help="Path to pretrained retriever or identifier from huggingface.co/models.")
],
retriever_peft_model_path: Annotated[str, typer.Option(help="Path to the fine-tuned retriever peft layers")],
passage_column_name: Annotated[str, typer.Option(help="Name of the column containing the passage")] = "Abstract",
query_column_name: Annotated[str, typer.Option(help="Name of the column containing the query")] = "Question",
embed_dim: Annotated[int, typer.Option(help="Dimension of the model embedding")] = 1024,
max_length: Annotated[
int, typer.Option(help="The max passage sequence length during tokenization. Longer sequences are truncated")
] = 128,
test_batch_size: Annotated[int, typer.Option(help="Batch size (per device) for the test dataloader.")] = 8,
device: Annotated[str, typer.Option(help="Device. cpu or cuda.")] = "cuda",
torch_dtype: Annotated[
TorchDtype, typer.Option(help="torch.dtype to use for tensors. float16 or bfloat16.")
] = TorchDtype.float16,
top_k: Annotated[int, typer.Option(help="Top K retrieval")] = 10,
) -> None:
"""Evaluate your retriever only"""
evaluate_retriever(
dataset_or_path=dataset_path,
retriever_name_or_path=retriever_name_or_path,
retriever_peft_model_path=retriever_peft_model_path,
passage_column_name=passage_column_name,
query_column_name=query_column_name,
embed_dim=embed_dim,
max_length=max_length,
test_batch_size=test_batch_size,
device=device,
torch_dtype=torch_dtype.value,
top_k=top_k,
)


if __name__ == "__main__":
cli()
