Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add monot5 tpu train doc #108

Merged
merged 3 commits into from
Nov 11, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -90,3 +90,4 @@ The following documents describe how to use Pygaggle on various IR test collecti
+ [Experiments on MS MARCO Document Retrieval](https://github.com/castorini/pygaggle/blob/master/docs/experiments-msmarco-document.md)
+ [Experiments on MS MARCO Passage Retrieval - Dev Subset](https://github.com/castorini/pygaggle/blob/master/docs/experiments-msmarco-passage-subset.md)
+ [Experiments on MS MARCO Passage Retrieval - Entire Dev Set](https://github.com/castorini/pygaggle/blob/master/docs/experiments-msmarco-passage-entire.md)
+ [Experiments on MS MARCO Passage Retrieval - with TPU](https://github.com/castorini/pygaggle/blob/master/docs/experiments-monot5-tpu.md)
59 changes: 59 additions & 0 deletions docs/experiments-monot5-tpu.md
Original file line number Diff line number Diff line change
Expand Up @@ -211,4 +211,63 @@ You should see the same result.

If you were able to replicate these results, please submit a PR adding to the replication log! Please mention in your PR if you find any difference!

## Train monoT5

First, download the MS MARCO train triples:
```
wget https://storage.googleapis.com/duobert_git/triples.train.small.tar.gz
tar -xvf triples.train.small.tar.gz
rm triples.train.small.tar.gz
```

Then convert the train triples to t5 input format:
```
python -m pygaggle.data.create_msmarco_t5_training_pairs --triples_train triples.train.small.tsv --output_to_t5 query_doc_pairs.train.tsv
```

Next, copy the input file to Google Storage. TPU training will read data directly from `gs`
```
gsutil cp query_doc_pairs.train.tsv ${GS_FOLDER}/query_doc_pairs.train.tsv
```

Recall the environment variables
```
export MODEL=<t5 pretrain model, e.g. base, large, 3B>
export GS_FOLDER=<gs folder to store checkpoints>
export PROJECT_NAME=<gcloud project name>
export TPU_NAME=<name of tpu to create>
export BASE_CKPT=<initial model checkpoint, e.g. 999900>
```

Copy pre-trained checkpoint to our target model
```
echo "model_checkpoint_path: \"model.ckpt-${BASE_CKPT}\"" > checkpoint
gsutil cp checkpoint ${GS_FOLDER}
gsutil cp gs://t5-data/pretrained_models/${MODEL}/model.ckpt-${BASE_CKPT}* ${GS_FOLDER}
```

```
nohup t5_mesh_transformer \
--tpu="${TPU_NAME}" \
--gcp_project="${PROJECT_NAME}" \
--tpu_zone="europe-west4-a" \
--model_dir="${GS_FOLDER}" \
--gin_param="init_checkpoint = 'gs://t5-data/pretrained_models/${MODEL}/model.ckpt-${BASE_CKPT}'" \
--gin_file="dataset.gin" \
--gin_file="models/bi_v1.gin" \
--gin_file="gs://t5-data/pretrained_models/${MODEL}/operative_config.gin" \
--gin_param="utils.tpu_mesh_shape.model_parallelism = 1" \
--gin_param="utils.tpu_mesh_shape.tpu_topology = '2x2'" \
--gin_param="utils.run.train_dataset_fn = @t5.models.mesh_transformer.tsv_dataset_fn" \
--gin_param="tsv_dataset_fn.filename = '${GS_FOLDER}/query_doc_pairs.train.tsv'" \
--gin_file="learning_rate_schedules/constant_0_001.gin" \
--gin_param="run.train_steps = 1100000" \
--gin_param="tokens_per_batch = 65536" \
>> out.log_exp 2>&1 &

tail -100f out.log_exp
```

Training T5 base, large, and 3B take approximately 12, 48, and 160 hours overall, respectively, on a single TPU.

## Replication Log
21 changes: 21 additions & 0 deletions pygaggle/data/create_msmarco_t5_training_pairs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
"""
This script creates monoT5 input files for training,
Each line in the monoT5 input file follows the format:
f'Query: {query} Document: {document} Relevant:\t{label}\n')
"""
import argparse
from tqdm import tqdm

parser = argparse.ArgumentParser()
parser.add_argument("--triples_train", type=str, required=True,
help="tsv file <query>, <positive_document>, <negative_document>")
parser.add_argument("--output_to_t5", type=str, required=True,
help="t5 train input file")
args = parser.parse_args()

with open(args.output_to_t5, 'w') as fout_t5:
for line_num, line in enumerate(tqdm(open(args.triples_train))):
query, positive_document, negative_document = line.strip().split('\t')
fout_t5.write(f'Query: {query} Document: {positive_document} Relevant:\ttrue\n')
fout_t5.write(f'Query: {query} Document: {negative_document} Relevant:\tfalse\n')
print('Done!')