[seq2seq] make it easier to run the scripts #7274

Merged
merged 4 commits on Sep 24, 2020
14 changes: 7 additions & 7 deletions examples/seq2seq/README.md
@@ -92,7 +92,7 @@ All finetuning bash scripts call finetune.py (or distillation.py) with reasonabl
To see all the possible command line options, run:

```bash
./finetune.sh --help # this calls python finetune.py --help
./finetune.py --help
```

### Finetuning Training Params
@@ -189,7 +189,7 @@ If 'translation' is in your task name, the computed metric will be BLEU. Otherwi
For t5, you need to specify --task translation_{src}_to_{tgt} as follows:
```bash
export DATA_DIR=wmt_en_ro
python run_eval.py t5-base \
./run_eval.py t5-base \
$DATA_DIR/val.source t5_val_generations.txt \
--reference_path $DATA_DIR/val.target \
--score_path enro_bleu.json \
@@ -203,7 +203,7 @@ python run_eval.py t5-base \
This command works for MBART, although the BLEU score is suspiciously low.
```bash
export DATA_DIR=wmt_en_ro
python run_eval.py facebook/mbart-large-en-ro $DATA_DIR/val.source mbart_val_generations.txt \
./run_eval.py facebook/mbart-large-en-ro $DATA_DIR/val.source mbart_val_generations.txt \
--reference_path $DATA_DIR/val.target \
--score_path enro_bleu.json \
--task translation \
@@ -216,7 +216,7 @@ python run_eval.py facebook/mbart-large-en-ro $DATA_DIR/val.source mbart_val_gen
Summarization (xsum will be very similar):
```bash
export DATA_DIR=cnn_dm
python run_eval.py sshleifer/distilbart-cnn-12-6 $DATA_DIR/val.source dbart_val_generations.txt \
./run_eval.py sshleifer/distilbart-cnn-12-6 $DATA_DIR/val.source dbart_val_generations.txt \
--reference_path $DATA_DIR/val.target \
--score_path cnn_rouge.json \
--task summarization \
@@ -230,7 +230,7 @@ python run_eval.py sshleifer/distilbart-cnn-12-6 $DATA_DIR/val.source dbart_val_
### Multi-GPU Evaluation
Here is a command to run xsum evaluation on 8 GPUs. It is more than linearly faster than run_eval.py in some cases
because it uses SortishSampler to minimize padding. You can also use it on 1 GPU. `data_dir` must have
`{type_path}.source` and `{type_path}.target`. Run `python run_distributed_eval.py --help` for all clargs.
`{type_path}.source` and `{type_path}.target`. Run `./run_distributed_eval.py --help` for all clargs.

```bash
python -m torch.distributed.launch --nproc_per_node=8 run_distributed_eval.py \
@@ -363,11 +363,11 @@ This feature can only be used:
- with fairseq installed
- on 1 GPU
- without sortish sampler
- after calling `python save_len_file.py $tok $data_dir`
- after calling `./save_len_file.py $tok $data_dir`

For example,
```bash
python save_len_file.py Helsinki-NLP/opus-mt-en-ro wmt_en_ro
./save_len_file.py Helsinki-NLP/opus-mt-en-ro wmt_en_ro
./dynamic_bs_example.sh --max_tokens_per_batch=2000 --output_dir benchmark_dynamic_bs
```
splits `wmt_en_ro/train` into 11,197 uneven-length batches and can finish 1 epoch in 8 minutes on a V100.
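
For intuition, the sketch below is a hypothetical illustration of what `--max_tokens_per_batch` does: given precomputed example lengths (like the ones `./save_len_file.py` writes), it groups similarly sized examples so that each batch's padded token count stays under the budget. The function name and details here are illustrative only; the actual feature relies on fairseq's batching utilities rather than this code.

```python
from typing import List


def batch_by_token_budget(lengths: List[int], max_tokens_per_batch: int = 2000) -> List[List[int]]:
    """Group example indices into batches whose padded size stays under the token budget."""
    # Sort indices by length so each batch holds similarly sized examples and padding waste is small.
    order = sorted(range(len(lengths)), key=lambda i: lengths[i])
    batches, current = [], []
    for idx in order:
        # Padded cost of the batch if this example is added: batch size * longest example in it.
        longest = max([lengths[i] for i in current] + [lengths[idx]])
        if current and (len(current) + 1) * longest > max_tokens_per_batch:
            batches.append(current)
            current = []
        current.append(idx)
    if current:
        batches.append(current)
    return batches


# Toy usage: batches come out unevenly sized, and long outliers end up nearly alone.
print(batch_by_token_budget([12, 400, 35, 7, 1500, 90], max_tokens_per_batch=2000))
```

Because batch sizes adapt to sequence length, the number of batches per epoch is data dependent, which is why the count above is uneven rather than fixed.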
2 changes: 2 additions & 0 deletions examples/seq2seq/convert_model_to_fp16.py
100644 → 100755
@@ -1,3 +1,5 @@
#!/usr/bin/env python

from typing import Union

import fire
2 changes: 2 additions & 0 deletions examples/seq2seq/convert_pl_checkpoint_to_hf.py
100644 → 100755
@@ -1,3 +1,5 @@
#!/usr/bin/env python

import os
from pathlib import Path
from typing import Dict, List
9 changes: 8 additions & 1 deletion examples/seq2seq/distillation.py
100644 → 100755
@@ -1,6 +1,9 @@
#!/usr/bin/env python

import argparse
import gc
import os
import sys
import warnings
from pathlib import Path
from typing import List
@@ -13,7 +16,6 @@
from finetune import SummarizationModule, TranslationModule
from finetune import main as ft_main
from initialization_utils import copy_layers, init_student
from lightning_base import generic_train
from transformers import AutoModelForSeq2SeqLM, MBartTokenizer, T5Config, T5ForConditionalGeneration
from transformers.modeling_bart import shift_tokens_right
from utils import (
@@ -27,6 +29,11 @@
)


# need the parent dir module
sys.path.insert(2, str(Path(__file__).resolve().parents[1]))
from lightning_base import generic_train # noqa


class BartSummarizationDistiller(SummarizationModule):
"""Supports Bart, Pegasus and other models that inherit from Bart."""

2 changes: 2 additions & 0 deletions examples/seq2seq/download_wmt.py
100644 → 100755
@@ -1,3 +1,5 @@
#!/usr/bin/env python

from pathlib import Path

import fire
9 changes: 8 additions & 1 deletion examples/seq2seq/finetune.py
100644 → 100755
@@ -1,7 +1,10 @@
#!/usr/bin/env python

import argparse
import glob
import logging
import os
import sys
import time
from collections import defaultdict
from pathlib import Path
@@ -13,7 +16,6 @@
from torch.utils.data import DataLoader

from callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback, get_early_stopping_callback
from lightning_base import BaseTransformer, add_generic_args, generic_train
from transformers import MBartTokenizer, T5ForConditionalGeneration
from transformers.modeling_bart import shift_tokens_right
from utils import (
@@ -35,6 +37,11 @@
)


# need the parent dir module
sys.path.insert(2, str(Path(__file__).resolve().parents[1]))
from lightning_base import BaseTransformer, add_generic_args, generic_train # noqa


logger = logging.getLogger(__name__)


3 changes: 0 additions & 3 deletions examples/seq2seq/finetune.sh
@@ -1,6 +1,3 @@
# Add parent directory to python path to access lightning_base.py
export PYTHONPATH="../":"${PYTHONPATH}"

# the proper usage is documented in the README, you need to specify data_dir, output_dir and model_name_or_path
# run ./finetune.sh --help to see all the possible options
python finetune.py \
2 changes: 2 additions & 0 deletions examples/seq2seq/minify_dataset.py
100644 → 100755
@@ -1,3 +1,5 @@
#!/usr/bin/env python

from pathlib import Path

import fire
2 changes: 2 additions & 0 deletions examples/seq2seq/pack_dataset.py
100644 → 100755
@@ -1,3 +1,5 @@
#!/usr/bin/env python

"""Fill examples with bitext up to max_tokens without breaking up examples.
[['I went', 'yo fui'],
['to the store', 'a la tienda']
2 changes: 2 additions & 0 deletions examples/seq2seq/run_distributed_eval.py
100644 → 100755
@@ -1,3 +1,5 @@
#!/usr/bin/env python

import argparse
import shutil
import time
2 changes: 2 additions & 0 deletions examples/seq2seq/run_eval.py
100644 → 100755
@@ -1,3 +1,5 @@
#!/usr/bin/env python

import argparse
import datetime
import json
2 changes: 2 additions & 0 deletions examples/seq2seq/run_eval_search.py
100644 → 100755
@@ -1,3 +1,5 @@
#!/usr/bin/env python

import argparse
import itertools
import operator
9 changes: 3 additions & 6 deletions examples/seq2seq/save_len_file.py
100644 → 100755
@@ -1,14 +1,11 @@
#!/usr/bin/env python

import fire
from torch.utils.data import DataLoader
from tqdm import tqdm

from transformers import AutoTokenizer


try:
from .utils import Seq2SeqDataset, pickle_save
except ImportError:
from utils import Seq2SeqDataset, pickle_save
from utils import Seq2SeqDataset, pickle_save


def save_len_file(
2 changes: 2 additions & 0 deletions examples/seq2seq/test_bash_script.py
@@ -1,3 +1,5 @@
#!/usr/bin/env python

import argparse
import os
import sys
Binary file modified examples/seq2seq/test_data/wmt_en_ro/train.len
Binary file modified examples/seq2seq/test_data/wmt_en_ro/val.len
9 changes: 4 additions & 5 deletions examples/seq2seq/test_datasets.py
@@ -6,14 +6,13 @@
import pytest
from torch.utils.data import DataLoader

from pack_dataset import pack_data_dir
from save_len_file import save_len_file
from test_seq2seq_examples import ARTICLES, BART_TINY, MARIAN_TINY, MBART_TINY, SUMMARIES, T5_TINY, make_test_data_dir
from transformers import AutoTokenizer
from transformers.modeling_bart import shift_tokens_right
from transformers.testing_utils import slow

from .pack_dataset import pack_data_dir
from .save_len_file import save_len_file
from .test_seq2seq_examples import ARTICLES, BART_TINY, MARIAN_TINY, MBART_TINY, SUMMARIES, T5_TINY, make_test_data_dir
from .utils import FAIRSEQ_AVAILABLE, DistributedSortishSampler, LegacySeq2SeqDataset, Seq2SeqDataset
from utils import FAIRSEQ_AVAILABLE, DistributedSortishSampler, LegacySeq2SeqDataset, Seq2SeqDataset


BERT_BASE_CASED = "bert-base-cased"
10 changes: 2 additions & 8 deletions examples/seq2seq/test_fsmt_bleu_score.py
@@ -14,19 +14,13 @@
# limitations under the License.

import io
import unittest


try:
from .utils import calculate_bleu
except ImportError:
from utils import calculate_bleu

import json
import unittest

from parameterized import parameterized
from transformers import FSMTForConditionalGeneration, FSMTTokenizer
from transformers.testing_utils import get_tests_dir, require_torch, slow, torch_device
from utils import calculate_bleu


filename = get_tests_dir() + "/test_data/fsmt/fsmt_val_data.json"