Skip to content

Commit

Permalink
Merge pull request #169 from texttron/tevatron-v2
Browse files Browse the repository at this point in the history
tevatron-v2 update: unified toolkit across scale, language and modality
  • Loading branch information
MXueguang authored Feb 19, 2025
2 parents 498bc9b + ac8a9c7 commit 8afb331
Show file tree
Hide file tree
Showing 24 changed files with 1,501 additions and 97 deletions.
48 changes: 29 additions & 19 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,14 +1,24 @@
# Tevatron V1.5
Tevatron aims to provide a flexible and efficient toolkit that enables training and inference for neural retrieval models at scale.
# Tevatron V2.0

> Some of the features in Tevatron v1 is not yet migrated to Tevatron v1.5. We are working on it.
<div align="center">
<a href="https://arxiv.org/abs/2203.05765" target="_blank"><img src=https://img.shields.io/badge/arXiv-b5212f.svg?logo=arxiv></a>
<a href="https://huggingface.co/Tevatron" target="_blank"><img src=https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace%20Datasets-27b3b4.svg></a>
<a href="https://opensource.org/license/apache-2-0"><img src="https://img.shields.io/static/v1?label=License&message=Apache-2.0&color=red"></a>
<a href="https://pepy.tech/projects/tevatron"><img src="https://static.pepy.tech/badge/tevatron" alt="PyPI Downloads"></a>
<a href="https://star-history.com/#texttron/tevatron"> <img src="https://img.shields.io/github/stars/texttron/tevatron?style=social" alt="GitHub stars"> </a>
<!-- -->
</div>

Tevatron: Unified Document Retrieval Toolkit across Scale, Language, and Modality.

> Some of the features in Tevatron v1 is not yet migrated to Tevatron v2.0. We are working on it.
> If you are looking for the Tevatron v1 features, please pull the [v1 branch](https://github.com/texttron/tevatron/tree/tevatron-v1).
## Features
- Training billion-scale LLM neural retriever on GPUs and TPUs.
- Parameter efficient tuning with LoRA.
- Integration with DeepSpeed, flash attention, gradient accumulation, and other efficient training techniques.
- Self-contained datasets for neural retrieval and open-domain QA tasks.
- Integration with vLLM, DeepSpeed, FlashAttention, gradient accumulation, and other efficient training and inference techniques.
- Self-contained [huggingface datasets](https://huggingface.co/Tevatron) for multi-modal and multilingual neural retrieval and open-domain QA tasks.
- Direct loading and finetuning SoTA pre-trained models (BGE-Embbedding, Instruct-E5) from HuggingFace.

## Installation
Expand All @@ -21,7 +31,7 @@ Tevatron aims to provide a flexible and efficient toolkit that enables training
```bash
pip install transformers datasets peft
pip install deepspeed accelerate
pip install faiss
pip install faiss-cpu
pip install -e .
```

Expand Down Expand Up @@ -90,15 +100,10 @@ Tevatron takes training or inference data in `jsonl` format with each line organ
```json
{
"query_id": "<query id>",
"query": "<query text>",
"positive_passages": [
{"docid": "<passage id>", "title": "<passage title>", "text": "<passage body>"},
...
],
"negative_passages": [
{"docid": "<passage id>", "title": "<passage title>", "text": "<passage body>"},
...
]
"query_text": "<query text>",
"query_image": "<query image>",
"positive_document_ids": ["<passage id>", ...],
"negative_document_ids": ["<passage id>", ...],
}
```
where the passages in `positive_passages` are the annotated relevant passages of the `query`
Expand All @@ -107,12 +112,14 @@ and passages in `negative_passages` are usually non-relevant (hard negative) pas
#### 2. Corpus Data
```json
{
"docid": "<passage id>",
"title": "<passage title>",
"text": "<passage body>"
"docid": "<document id>",
"document_text": "<document text>",
"document_image": "<document image>",
}
```
where each line represents a passage in the corpus.
where each line represents a document in the corpus.

Note that the image field for both training and corpus data are optional and can be omitted (i.e., pure textual modality retrieval).

### Self-Contained Dataset
Tevatron self-contained several commonlly used datasets for neural retrieval.
Expand Down Expand Up @@ -323,6 +330,9 @@ The output file is in the format of `<query_id> <passage_id> <score>` in each li

</details>

## Examples
+ [Unified multi-modal and multilingual retrieval](./examples/multimodal/README.md)
+ [vLLM encoding and retrieval](./examples/example_repllama_vllm.md)

## Citation
If you find Tevatron helpful, please consider citing our [paper](https://arxiv.org/abs/2203.05765).
Expand Down
65 changes: 65 additions & 0 deletions deepspeed/ds_zero0_config.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
{
"zero_optimization": {
"stage": 0,
"offload_optimizer": {
"device": "none",
"pin_memory": true
},
"offload_param": {
"device": "none",
"pin_memory": true
},
"overlap_comm": true,
"contiguous_gradients": true,
"sub_group_size": 1e9,
"reduce_bucket_size": 1e6,
"stage3_prefetch_bucket_size": "auto",
"stage3_param_persistence_threshold": "auto",
"stage3_max_live_parameters": 1e9,
"stage3_max_reuse_distance": 1e9,
"stage3_gather_16bit_weights_on_model_save": true
},
"fp16": {
"enabled": "auto",
"loss_scale": 0,
"initial_scale_power": 10,
"loss_scale_window": 1000,
"hysteresis": 2,
"min_loss_scale": 1
},
"bf16": {
"enabled": "auto",
"loss_scale": 0,
"initial_scale_power": 10,
"loss_scale_window": 1000,
"hysteresis": 2,
"min_loss_scale": 1
},
"optimizer": {
"type": "AdamW",
"params": {
"lr": "auto",
"betas": "auto",
"eps": "auto",
"weight_decay": "auto",
"torch_adam": true
}
},

"scheduler": {
"type": "WarmupDecayLR",
"params": {
"warmup_min_lr": "auto",
"warmup_max_lr": "auto",
"warmup_num_steps": "auto",
"total_num_steps": "auto"
}
},

"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"steps_per_print": 1000,
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"wall_clock_breakdown": false
}
4 changes: 2 additions & 2 deletions examples/example_repllama.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ CUDA_VISIBLE_DEVICES=0 python -m tevatron.retriever.driver.encode \
--dataset_config scifact \
--dataset_split train \
--encode_output_path beir_embedding_scifact/corpus_scifact.${s}.pkl \
--encode_num_shard 4 \
--encode_shard_index ${s}
--dataset_number_of_shards 4 \
--dataset_shard_index ${s}
done
```

Expand Down
4 changes: 2 additions & 2 deletions examples/example_repllama_vllm.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ CUDA_VISIBLE_DEVICES=0 python -m tevatron.retriever.driver.vllm_encode \
--dataset_config scifact \
--dataset_split train \
--encode_output_path beir_embedding_scifact/corpus_scifact.${s}.pkl \
--encode_num_shard 4 \
--encode_shard_index ${s}
--dataset_number_of_shards 4 \
--dataset_shard_index ${s}
done
```

Expand Down
205 changes: 205 additions & 0 deletions examples/multimodal/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,205 @@
# Unified Multi-modal and Multilingual Retrieval

## Train
```bash
deepspeed --include localhost:0,1,2,3,4,5,6,7,8 --master_port 60000 --module tevatron.retriever.driver.train_mm \
--deepspeed deepspeed/ds_zero0_config.json \
--output_dir retriever-qwen25vl-bge-pixmo-colpali-wiki \
--model_name_or_path Qwen/Qwen2.5-VL-3B-Instruct \
--lora \
--lora_target_modules q_proj,k_proj,v_proj,o_proj,gate_proj,up_proj,down_proj \
--save_steps 500 \
--train_yaml dataset_config.yaml \
--query_prefix "Query: " \
--passage_prefix "" \
--bf16 \
--tf32 True \
--pooling eos \
--append_eos_token \
--normalize \
--temperature 0.02 \
--per_device_train_batch_size 16 \
--gradient_checkpointing \
--train_group_size 4 \
--learning_rate 1e-4 \
--query_max_len 512 \
--passage_max_len 512 \
--num_train_epochs 1 \
--logging_steps 1 \
--overwrite_output_dir \
--gradient_accumulation_steps 2 \
--warmup_ratio 0.005 \
--report_to wandb \
--dataloader_num_workers 4
```

## Inference and evaluation

### BEIR (textual modality)

#### Query Encode
```bash

CKPT=retriever-qwen25vl-bge-pixmo-colpali-wiki
DATASET=scifact

mkdir -p beir_embedding/${CKPT}/${DATASET}
CUDA_VISIBLE_DEVICES=0 python -m tevatron.retriever.driver.encode_mm \
--output_dir=temp \
--model_name_or_path Qwen/Qwen2.5-VL-3B-Instruct \
--lora_name_or_path ${CKPT} \
--lora \
--bf16 \
--per_device_eval_batch_size 16 \
--normalize \
--pooling last \
--query_prefix "Query: " \
--passage_prefix "" \
--append_eos_token \
--query_max_len 512 \
--dataset_name Tevatron/beir \
--dataset_config ${DATASET} \
--dataset_split test \
--encode_output_path beir_embedding/${CKPT}/${DATASET}/queries.pkl \
--encode_is_query
```

#### Document Encode
```bash
for s in 0 1 2 3;
do
CUDA_VISIBLE_DEVICES=$s python -m tevatron.retriever.driver.encode_mm \
--output_dir=temp \
--model_name_or_path Qwen/Qwen2.5-VL-3B-Instruct \
--lora_name_or_path ${CKPT} \
--lora \
--bf16 \
--per_device_eval_batch_size 16 \
--normalize \
--pooling last \
--passage_prefix "" \
--append_eos_token \
--passage_max_len 512 \
--dataset_name Tevatron/beir-corpus \
--dataset_config ${DATASET} \
--dataset_split train \
--encode_output_path beir_embedding/${CKPT}/${DATASET}/corpus.${s}.pkl \
--dataset_number_of_shards 4 \
--dataset_shard_index ${s} &
done
wait
```



#### Search
```bash
mkdir -p beir_results/${CKPT}/scifact
python -m tevatron.retriever.driver.search \
--query_reps beir_embedding/${CKPT}/${DATASET}/queries.pkl \
--passage_reps beir_embedding/${CKPT}/${DATASET}/'corpus.*.pkl' \
--depth 100 \
--batch_size 64 \
--save_text \
--save_ranking_to beir_results/${CKPT}/${DATASET}/rank.scifact.txt
```

#### Evaluate
```bash
python -m tevatron.utils.format.convert_result_to_trec \
--input beir_results/${CKPT}/${DATASET}/rank.scifact.txt \
--output beir_results/${CKPT}/${DATASET}/rank.scifact.trec \
--remove_query

python -m pyserini.eval.trec_eval -c -mrecall.100 -mndcg_cut.10 beir-v1.0.0-scifact-test \
beir_results/${CKPT}/${DATASET}/rank.scifact.trec
```

### MIRACL (Multi-Lingual, Textual Modality)
#### Query Encode
```bash

CKPT=retriever-qwen25vl-bge-pixmo-colpali-wiki
DATASET=ar

mkdir -p miracl_embedding/${CKPT}/${DATASET}
CUDA_VISIBLE_DEVICES=0 python -m tevatron.retriever.driver.encode_mm \
--output_dir=temp \
--model_name_or_path Qwen/Qwen2.5-VL-3B-Instruct \
--lora_name_or_path ${CKPT} \
--lora \
--bf16 \
--per_device_eval_batch_size 16 \
--normalize \
--pooling last \
--query_prefix "Query: " \
--passage_prefix "" \
--append_eos_token \
--query_max_len 512 \
--dataset_name miracl/miracl \
--dataset_config $DATASET \
--dataset_split test \
--encode_output_path miracl_embedding/${CKPT}/${DATASET}/queries.pkl \
--encode_is_query
```

#### Document Encode
```bash
for s in 0 1 2 3;
do
CUDA_VISIBLE_DEVICES=$s python -m tevatron.retriever.driver.encode_mm \
--output_dir=temp \
--model_name_or_path Qwen/Qwen2.5-VL-3B-Instruct \
--lora_name_or_path ${CKPT} \
--lora \
--bf16 \
--per_device_eval_batch_size 16 \
--normalize \
--pooling last \
--passage_prefix "" \
--append_eos_token \
--passage_max_len 512 \
--dataset_name miracl/miracl-corpus \
--dataset_config ${DATASET} \
--dataset_split train \
--encode_output_path miracl_embedding/${CKPT}/${DATASET}/corpus.${s}.pkl \
--dataset_number_of_shards 4 \
--dataset_shard_index ${s} &
done
wait
```



#### Search
```bash
mkdir -p miracl_results/retriever-qwen25vl-bge-pixmo-colpali-wiki/$DATASET
python -m tevatron.retriever.driver.search \
--query_reps miracl_embedding/${CKPT}/${DATASET}/queries.pkl \
--passage_reps miracl_embedding/${CKPT}/${DATASET}/'corpus.*.pkl' \
--depth 100 \
--batch_size 64 \
--save_text \
--save_ranking_to miracl_results/${CKPT}/${DATASET}/rank.${DATASET}.txt
```

#### Evaluate
```bash
python -m tevatron.utils.format.convert_result_to_trec \
--input miracl_results/${CKPT}/${DATASET}/rank.${DATASET}.txt \
--output miracl_results/${CKPT}/${DATASET}/rank.${DATASET}.trec

python -m pyserini.eval.trec_eval -c -mrecall.100 -mndcg_cut.10 miracl-v1.0-${DATASET}-dev \
miracl_results/${CKPT}/${DATASET}/rank.${DATASET}.trec
```

### VIDORE Document screenshot retrieval (Cross modality)
```bash
CUDA_VISIBLE_DEVICES=0 python eval_vidore.py \
--model_name_or_path Qwen/Qwen2.5-VL-3B-Instruct \
--lora_name_or_path ${CKPT} \
--batch_size 4 \
--pooling last \
--normalize \
--query_prefix "Query: "
```
Loading

0 comments on commit 8afb331

Please sign in to comment.