From 67cadf087170205824d00453042294ba26f49991 Mon Sep 17 00:00:00 2001 From: Yifan Peng Date: Fri, 11 Feb 2022 22:53:05 -0500 Subject: [PATCH 1/2] modify model averaging --- espnet2/main_funcs/average_nbest_models.py | 18 +++++++++++++----- espnet2/train/trainer.py | 1 + 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/espnet2/main_funcs/average_nbest_models.py b/espnet2/main_funcs/average_nbest_models.py index e025238e80e..a7edc07b68a 100644 --- a/espnet2/main_funcs/average_nbest_models.py +++ b/espnet2/main_funcs/average_nbest_models.py @@ -1,5 +1,6 @@ import logging from pathlib import Path +from typing import Optional from typing import Sequence from typing import Union import warnings @@ -17,6 +18,7 @@ def average_nbest_models( reporter: Reporter, best_model_criterion: Sequence[Sequence[str]], nbest: Union[Collection[int], int], + suffix: Optional[str] = None, ) -> None: """Generate averaged model from n-best models @@ -25,7 +27,8 @@ def average_nbest_models( reporter: Reporter instance best_model_criterion: Give criterions to decide the best model. e.g. [("valid", "loss", "min"), ("train", "acc", "max")] - nbest: + nbest: Number of best model files to be averaged + suffix: A suffix added to the averaged model file name """ assert check_argument_types() if isinstance(nbest, int): @@ -35,6 +38,11 @@ def average_nbest_models( if len(nbests) == 0: warnings.warn("At least 1 nbest values are required") nbests = [1] + if suffix is not None: + suffix = suffix + "." + else: + suffix = "" + # 1. Get nbests: List[Tuple[str, str, List[Tuple[epoch, value]]]] nbest_epochs = [ (ph, k, reporter.sort_epochs_and_values(ph, k, m)[: max(nbests)]) @@ -55,12 +63,12 @@ def average_nbest_models( # The averaged model is same as the best model e, _ = epoch_and_values[0] op = output_dir / f"{e}epoch.pth" - sym_op = output_dir / f"{ph}.{cr}.ave_1best.pth" + sym_op = output_dir / f"{ph}.{cr}.ave_1best.{suffix}pth" if sym_op.is_symlink() or sym_op.exists(): sym_op.unlink() sym_op.symlink_to(op.name) else: - op = output_dir / f"{ph}.{cr}.ave_{n}best.pth" + op = output_dir / f"{ph}.{cr}.ave_{n}best.{suffix}pth" logging.info( f"Averaging {n}best models: " f'criterion="{ph}.{cr}": {op}' ) @@ -96,8 +104,8 @@ def average_nbest_models( torch.save(avg, op) # 3. *.*.ave.pth is a symlink to the max ave model - op = output_dir / f"{ph}.{cr}.ave_{max(_nbests)}best.pth" - sym_op = output_dir / f"{ph}.{cr}.ave.pth" + op = output_dir / f"{ph}.{cr}.ave_{max(_nbests)}best.{suffix}pth" + sym_op = output_dir / f"{ph}.{cr}.ave.{suffix}pth" if sym_op.is_symlink() or sym_op.exists(): sym_op.unlink() sym_op.symlink_to(op.name) diff --git a/espnet2/train/trainer.py b/espnet2/train/trainer.py index c2dd3ea4ee8..766651ddbaa 100644 --- a/espnet2/train/trainer.py +++ b/espnet2/train/trainer.py @@ -424,6 +424,7 @@ def run( output_dir=output_dir, best_model_criterion=trainer_options.best_model_criterion, nbest=keep_nbest_models, + suffix=f"till{iepoch}epoch", ) for e in range(1, iepoch): From 8e1448ac4f6618c05d94d2ee38b1b0c0ec378723 Mon Sep 17 00:00:00 2001 From: Yifan Peng Date: Fri, 11 Feb 2022 23:04:37 -0500 Subject: [PATCH 2/2] apply black --- espnet2/main_funcs/average_nbest_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/espnet2/main_funcs/average_nbest_models.py b/espnet2/main_funcs/average_nbest_models.py index a7edc07b68a..4c278e23823 100644 --- a/espnet2/main_funcs/average_nbest_models.py +++ b/espnet2/main_funcs/average_nbest_models.py @@ -42,7 +42,7 @@ def average_nbest_models( suffix = suffix + "." else: suffix = "" - + # 1. Get nbests: List[Tuple[str, str, List[Tuple[epoch, value]]]] nbest_epochs = [ (ph, k, reporter.sort_epochs_and_values(ph, k, m)[: max(nbests)])