diff --git a/README.md b/README.md index f927491b..f9a0cce5 100644 --- a/README.md +++ b/README.md @@ -90,3 +90,4 @@ The following documents describe how to use Pygaggle on various IR test collecti + [Experiments on MS MARCO Document Retrieval](https://github.com/castorini/pygaggle/blob/master/docs/experiments-msmarco-document.md) + [Experiments on MS MARCO Passage Retrieval - Dev Subset](https://github.com/castorini/pygaggle/blob/master/docs/experiments-msmarco-passage-subset.md) + [Experiments on MS MARCO Passage Retrieval - Entire Dev Set](https://github.com/castorini/pygaggle/blob/master/docs/experiments-msmarco-passage-entire.md) ++ [Experiments on MS MARCO Passage Retrieval - with TPU](https://github.com/castorini/pygaggle/blob/master/docs/experiments-monot5-tpu.md) diff --git a/docs/experiments-monot5-tpu.md b/docs/experiments-monot5-tpu.md index f726a046..b1eed929 100644 --- a/docs/experiments-monot5-tpu.md +++ b/docs/experiments-monot5-tpu.md @@ -211,4 +211,63 @@ You should see the same result. If you were able to replicate these results, please submit a PR adding to the replication log! Please mention in your PR if you find any difference! +## Train monoT5 + +First, download the MS MARCO train triples: +``` +wget https://storage.googleapis.com/duobert_git/triples.train.small.tar.gz +tar -xvf triples.train.small.tar.gz +rm triples.train.small.tar.gz +``` + +Then convert the train triples to t5 input format: +``` +python -m pygaggle.data.create_msmarco_t5_training_pairs --triples_train triples.train.small.tsv --output_to_t5 query_doc_pairs.train.tsv +``` + +Next, copy the input file to Google Storage. TPU training will read data directly from `gs` +``` +gsutil cp query_doc_pairs.train.tsv ${GS_FOLDER}/query_doc_pairs.train.tsv +``` + +Recall the environment variables +``` +export MODEL= +export GS_FOLDER= +export PROJECT_NAME= +export TPU_NAME= +export BASE_CKPT= +``` + +Copy pre-trained checkpoint to our target model +``` +echo "model_checkpoint_path: \"model.ckpt-${BASE_CKPT}\"" > checkpoint +gsutil cp checkpoint ${GS_FOLDER} +gsutil cp gs://t5-data/pretrained_models/${MODEL}/model.ckpt-${BASE_CKPT}* ${GS_FOLDER} +``` + +``` +nohup t5_mesh_transformer \ + --tpu="${TPU_NAME}" \ + --gcp_project="${PROJECT_NAME}" \ + --tpu_zone="europe-west4-a" \ + --model_dir="${GS_FOLDER}" \ + --gin_param="init_checkpoint = 'gs://t5-data/pretrained_models/${MODEL}/model.ckpt-${BASE_CKPT}'" \ + --gin_file="dataset.gin" \ + --gin_file="models/bi_v1.gin" \ + --gin_file="gs://t5-data/pretrained_models/${MODEL}/operative_config.gin" \ + --gin_param="utils.tpu_mesh_shape.model_parallelism = 1" \ + --gin_param="utils.tpu_mesh_shape.tpu_topology = '2x2'" \ + --gin_param="utils.run.train_dataset_fn = @t5.models.mesh_transformer.tsv_dataset_fn" \ + --gin_param="tsv_dataset_fn.filename = '${GS_FOLDER}/query_doc_pairs.train.tsv'" \ + --gin_file="learning_rate_schedules/constant_0_001.gin" \ + --gin_param="run.train_steps = 1100000" \ + --gin_param="tokens_per_batch = 65536" \ + >> out.log_exp 2>&1 & + +tail -100f out.log_exp +``` + +Training T5 base, large, and 3B take approximately 12, 48, and 160 hours overall, respectively, on a single TPU. + ## Replication Log diff --git a/pygaggle/data/create_msmarco_t5_training_pairs.py b/pygaggle/data/create_msmarco_t5_training_pairs.py new file mode 100644 index 00000000..186b5798 --- /dev/null +++ b/pygaggle/data/create_msmarco_t5_training_pairs.py @@ -0,0 +1,21 @@ +""" +This script creates monoT5 input files for training, +Each line in the monoT5 input file follows the format: + f'Query: {query} Document: {document} Relevant:\t{label}\n') +""" +import argparse +from tqdm import tqdm + +parser = argparse.ArgumentParser() +parser.add_argument("--triples_train", type=str, required=True, + help="tsv file , , ") +parser.add_argument("--output_to_t5", type=str, required=True, + help="t5 train input file") +args = parser.parse_args() + +with open(args.output_to_t5, 'w') as fout_t5: + for line_num, line in enumerate(tqdm(open(args.triples_train))): + query, positive_document, negative_document = line.strip().split('\t') + fout_t5.write(f'Query: {query} Document: {positive_document} Relevant:\ttrue\n') + fout_t5.write(f'Query: {query} Document: {negative_document} Relevant:\tfalse\n') +print('Done!')