diff --git a/examples/seq2seq/test_bash_script.py b/examples/seq2seq/test_bash_script.py index 71861ef4dbc6a3..7a7cd0806794a8 100644 --- a/examples/seq2seq/test_bash_script.py +++ b/examples/seq2seq/test_bash_script.py @@ -3,92 +3,107 @@ import argparse import os import sys -from pathlib import Path from unittest.mock import patch -import pytest import pytorch_lightning as pl import timeout_decorator import torch from distillation import BartSummarizationDistiller, distill_main from finetune import SummarizationModule, main -from test_seq2seq_examples import CUDA_AVAILABLE, MBART_TINY -from transformers import BartForConditionalGeneration, MarianMTModel -from transformers.testing_utils import TestCasePlus, slow +from transformers import MarianMTModel +from transformers.file_utils import cached_path +from transformers.testing_utils import TestCasePlus, require_torch_gpu, slow from utils import load_json -MODEL_NAME = MBART_TINY -MARIAN_MODEL = "sshleifer/student_marian_en_ro_6_1" +MARIAN_MODEL = "sshleifer/mar_enro_6_3_student" -class TestAll(TestCasePlus): +class TestMbartCc25Enro(TestCasePlus): + def setUp(self): + super().setUp() + + data_cached = cached_path( + "https://cdn-datasets.huggingface.co/translation/wmt_en_ro-tr40k-va0.5k-te0.5k.tar.gz", + extract_compressed_file=True, + ) + self.data_dir = f"{data_cached}/wmt_en_ro-tr40k-va0.5k-te0.5k" + @slow - @pytest.mark.skipif(not CUDA_AVAILABLE, reason="too slow to run on CPU") + @require_torch_gpu def test_model_download(self): """This warms up the cache so that we can time the next test without including download time, which varies between machines.""" - BartForConditionalGeneration.from_pretrained(MODEL_NAME) MarianMTModel.from_pretrained(MARIAN_MODEL) - @timeout_decorator.timeout(120) + # @timeout_decorator.timeout(1200) @slow - @pytest.mark.skipif(not CUDA_AVAILABLE, reason="too slow to run on CPU") + @require_torch_gpu def test_train_mbart_cc25_enro_script(self): - data_dir = "examples/seq2seq/test_data/wmt_en_ro" env_vars_to_replace = { - "--fp16_opt_level=O1": "", - "$MAX_LEN": 128, - "$BS": 4, + "$MAX_LEN": 64, + "$BS": 64, "$GAS": 1, - "$ENRO_DIR": data_dir, - "facebook/mbart-large-cc25": MODEL_NAME, - # Download is 120MB in previous test. - "val_check_interval=0.25": "val_check_interval=1.0", + "$ENRO_DIR": self.data_dir, + "facebook/mbart-large-cc25": MARIAN_MODEL, + # "val_check_interval=0.25": "val_check_interval=1.0", + "--learning_rate=3e-5": "--learning_rate 3e-4", + "--num_train_epochs 6": "--num_train_epochs 1", } # Clean up bash script - bash_script = Path("examples/seq2seq/train_mbart_cc25_enro.sh").open().read().split("finetune.py")[1].strip() + bash_script = (self.test_file_dir / "train_mbart_cc25_enro.sh").open().read().split("finetune.py")[1].strip() bash_script = bash_script.replace("\\\n", "").strip().replace('"$@"', "") for k, v in env_vars_to_replace.items(): bash_script = bash_script.replace(k, str(v)) output_dir = self.get_auto_remove_tmp_dir() - bash_script = bash_script.replace("--fp16 ", "") - testargs = ( - ["finetune.py"] - + bash_script.split() - + [ - f"--output_dir={output_dir}", - "--gpus=1", - "--learning_rate=3e-1", - "--warmup_steps=0", - "--val_check_interval=1.0", - "--tokenizer_name=facebook/mbart-large-en-ro", - ] - ) + # bash_script = bash_script.replace("--fp16 ", "") + args = f""" + --output_dir {output_dir} + --tokenizer_name Helsinki-NLP/opus-mt-en-ro + --sortish_sampler + --do_predict + --gpus 1 + --freeze_encoder + --n_train 40000 + --n_val 500 + --n_test 500 + --fp16_opt_level O1 + --num_sanity_val_steps 0 + --eval_beams 2 + """.split() + # XXX: args.gpus > 1 : handle multigpu in the future + + testargs = ["finetune.py"] + bash_script.split() + args with patch.object(sys, "argv", testargs): parser = argparse.ArgumentParser() parser = pl.Trainer.add_argparse_args(parser) parser = SummarizationModule.add_model_specific_args(parser, os.getcwd()) args = parser.parse_args() - args.do_predict = False - # assert args.gpus == gpus THIS BREAKS for multigpu model = main(args) # Check metrics metrics = load_json(model.metrics_save_path) first_step_stats = metrics["val"][0] last_step_stats = metrics["val"][-1] - assert ( - len(metrics["val"]) == (args.max_epochs / args.val_check_interval) + 1 - ) # +1 accounts for val_sanity_check + self.assertEqual(len(metrics["val"]), (args.max_epochs / args.val_check_interval)) + assert isinstance(last_step_stats[f"val_avg_{model.val_metric}"], float) - assert last_step_stats["val_avg_gen_time"] >= 0.01 + self.assertGreater(last_step_stats["val_avg_gen_time"], 0.01) + # model hanging on generate. Maybe bad config was saved. (XXX: old comment/assert?) + self.assertLessEqual(last_step_stats["val_avg_gen_time"], 1.0) - assert first_step_stats["val_avg_bleu"] < last_step_stats["val_avg_bleu"] # model learned nothing - assert 1.0 >= last_step_stats["val_avg_gen_time"] # model hanging on generate. Maybe bad config was saved. - assert isinstance(last_step_stats[f"val_avg_{model.val_metric}"], float) + # test learning requirements: + + # 1. BLEU improves over the course of training by more than 2 pts + self.assertGreater(last_step_stats["val_avg_bleu"] - first_step_stats["val_avg_bleu"], 2) + + # 2. BLEU finishes above 17 + self.assertGreater(last_step_stats["val_avg_bleu"], 17) + + # 3. test BLEU and val BLEU within ~1.1 pt. + self.assertLess(abs(metrics["val"][-1]["val_avg_bleu"] - metrics["test"][-1]["test_avg_bleu"]), 1.1) # check lightning ckpt can be loaded and has a reasonable statedict contents = os.listdir(output_dir) @@ -107,11 +122,13 @@ def test_train_mbart_cc25_enro_script(self): # assert len(metrics["val"]) == desired_n_evals assert len(metrics["test"]) == 1 + +class TestDistilMarianNoTeacher(TestCasePlus): @timeout_decorator.timeout(600) @slow - @pytest.mark.skipif(not CUDA_AVAILABLE, reason="too slow to run on CPU") + @require_torch_gpu def test_opus_mt_distill_script(self): - data_dir = "examples/seq2seq/test_data/wmt_en_ro" + data_dir = f"{self.test_file_dir_str}/test_data/wmt_en_ro" env_vars_to_replace = { "--fp16_opt_level=O1": "", "$MAX_LEN": 128, @@ -124,7 +141,7 @@ def test_opus_mt_distill_script(self): # Clean up bash script bash_script = ( - Path("examples/seq2seq/distil_marian_no_teacher.sh").open().read().split("distillation.py")[1].strip() + (self.test_file_dir / "distil_marian_no_teacher.sh").open().read().split("distillation.py")[1].strip() ) bash_script = bash_script.replace("\\\n", "").strip().replace('"$@"', "") bash_script = bash_script.replace("--fp16 ", " ")