From 8d501ae47bfcba2cb2bc5c004d0bf4f6dce7deff Mon Sep 17 00:00:00 2001 From: autumn <2> Date: Fri, 7 Jul 2023 14:11:20 +0800 Subject: [PATCH 01/23] support pre_train model --- scripts/train.py | 39 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/scripts/train.py b/scripts/train.py index 0ce341ca..ecab8688 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -1,7 +1,11 @@ import importlib import os +import pathlib import sys from pathlib import Path +import torch + +from utils.training_utils import get_latest_checkpoint_path root_dir = Path(__file__).parent.parent.resolve() os.environ['PYTHONPATH'] = str(root_dir) @@ -15,6 +19,39 @@ if hparams['ddp_backend'] == 'nccl_no_p2p': print("Disabling NCCL P2P") os.environ['NCCL_P2P_DISABLE'] = '1' +def load_pre_train_model(model,type:str='part'): + r""" load pre model + + Args: + model : + type (str): part mean loading wavenet full mean loading fs2 and wavenet + + """ + pre_train_ckpt_path=hparams.get('pre_train_path') + if pre_train_ckpt_path is not None: + ckpt=torch.load(pre_train_ckpt_path) + if type == "part": + state_dict = {} + for i in ckpt['state_dict']: + if 'diffusion' in i: + print(i) + state_dict[i] = ckpt['state_dict'][i] + model.load_state_dict(state_dict,strict=False) + elif type == "full": + ... + + +def load_warp(model): + if get_latest_checkpoint_path(pathlib.Path(hparams['work_dir'])) is not None: + pass + return None + type=hparams.get('pre_train_load_type') + if type is None: + type="part" + load_pre_train_model(model=model,type=type) + + + def run_task(): @@ -22,6 +59,8 @@ def run_task(): pkg = ".".join(hparams["task_cls"].split(".")[:-1]) cls_name = hparams["task_cls"].split(".")[-1] task_cls = getattr(importlib.import_module(pkg), cls_name) + load_warp(task_cls) + task_cls.start() From 22d02756540f7d7c417fee7ff83a51f3c9cf7f55 Mon Sep 17 00:00:00 2001 From: autumn <2> Date: Fri, 7 Jul 2023 16:29:25 +0800 Subject: [PATCH 02/23] support pre_train model --- basics/base_task.py | 6 +++++- scripts/train.py | 46 ++++++++++++++++++++++++++--------------- utils/training_utils.py | 15 ++++++++++++++ 3 files changed, 49 insertions(+), 18 deletions(-) diff --git a/basics/base_task.py b/basics/base_task.py index faf18812..afbeebb5 100644 --- a/basics/base_task.py +++ b/basics/base_task.py @@ -288,9 +288,13 @@ def on_test_end(self): ########### @classmethod - def start(cls): + def start(cls,pre_train=None): pl.seed_everything(hparams['seed'], workers=True) task = cls() + if pre_train is not None: + task.load_state_dict(pre_train,strict=False) + print("load success-------------------------------------------------------------------") + work_dir = pathlib.Path(hparams['work_dir']) trainer = pl.Trainer( accelerator=hparams['pl_trainer_accelerator'], diff --git a/scripts/train.py b/scripts/train.py index ecab8688..10b20f01 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -5,50 +5,62 @@ from pathlib import Path import torch -from utils.training_utils import get_latest_checkpoint_path + root_dir = Path(__file__).parent.parent.resolve() os.environ['PYTHONPATH'] = str(root_dir) sys.path.insert(0, str(root_dir)) os.environ['TORCH_CUDNN_V8_API_ENABLED'] = '1' # Prevent unacceptable slowdowns when using 16 precision - +from utils.training_utils import get_latest_checkpoint_path,remove_map from utils.hparams import set_hparams, hparams set_hparams() if hparams['ddp_backend'] == 'nccl_no_p2p': print("Disabling NCCL P2P") os.environ['NCCL_P2P_DISABLE'] = '1' -def load_pre_train_model(model,type:str='part'): +def load_pre_train_model(type='wavenet'): r""" load pre model Args: model : - type (str): part mean loading wavenet full mean loading fs2 and wavenet + type (str): wavenet mean loading wavenet full mean loading fs2 and wavenet """ pre_train_ckpt_path=hparams.get('pre_train_path') if pre_train_ckpt_path is not None: ckpt=torch.load(pre_train_ckpt_path) - if type == "part": - state_dict = {} - for i in ckpt['state_dict']: - if 'diffusion' in i: - print(i) - state_dict[i] = ckpt['state_dict'][i] - model.load_state_dict(state_dict,strict=False) - elif type == "full": + if ckpt.get('category') is None: + raise RuntimeError("") + if isinstance(type, str): + if type == "wavenet": + state_dict = {} + for i in ckpt['state_dict']: + if i in remove_map['base']: + continue + if 'diffusion' in i: + + state_dict[i] = ckpt['state_dict'][i] + return state_dict + # model.load_state_dict(state_dict=state_dict,strict=False) + elif type == "full": + ... + elif isinstance(type, list): ... + else: + raise RuntimeError("") + else: + return None -def load_warp(model): +def load_warp(): if get_latest_checkpoint_path(pathlib.Path(hparams['work_dir'])) is not None: pass return None type=hparams.get('pre_train_load_type') if type is None: - type="part" - load_pre_train_model(model=model,type=type) + type="wavenet" + return load_pre_train_model(type=type) @@ -59,9 +71,9 @@ def run_task(): pkg = ".".join(hparams["task_cls"].split(".")[:-1]) cls_name = hparams["task_cls"].split(".")[-1] task_cls = getattr(importlib.import_module(pkg), cls_name) - load_warp(task_cls) + ckpt =load_warp() - task_cls.start() + task_cls.start(ckpt) if __name__ == '__main__': diff --git a/utils/training_utils.py b/utils/training_utils.py index 1563daca..21f4e81e 100644 --- a/utils/training_utils.py +++ b/utils/training_utils.py @@ -16,6 +16,21 @@ import utils from utils.hparams import hparams +remove_map = {'base': ['model.diffusion.alphas_cumprod', 'model.diffusion.alphas_cumprod_prev', + 'model.diffusion.sqrt_alphas_cumprod', + 'model.diffusion.sqrt_one_minus_alphas_cumprod', 'model.diffusion.log_one_minus_alphas_cumprod', + 'model.diffusion.sqrt_recip_alphas_cumprod', 'model.diffusion.sqrt_recipm1_alphas_cumprod', + 'model.diffusion.posterior_variance', + 'model.diffusion.posterior_log_variance_clipped', 'model.diffusion.posterior_mean_coef1', + 'model.diffusion.posterior_mean_coef2', + 'model.diffusion.spec_min', 'model.diffusion.spec_max'], + 'speed_emb': ['model.fs2.speed_embed.weight', 'model.fs2.speed_embed.bias', ], + 'key_shift_emb': ['model.fs2.key_shift_embed.bias', 'model.fs2.key_shift_embed.weight', ] + , 'pitch_emb': ['model.fs2.pitch_embed.bias', 'model.fs2.pitch_embed.weight', ], + 'token_emb': ['model.fs2.encoder.embed_tokens.weight', 'model.fs2.txt_embed.weight'], + 'dur_emb': ['model.fs2.dur_embed.bias', 'model.fs2.dur_embed.weight']} + + # ==========LR schedulers========== From 73e254b345fd4569b11281775255c171a4f33f1f Mon Sep 17 00:00:00 2001 From: autumn <2> Date: Fri, 7 Jul 2023 19:20:09 +0800 Subject: [PATCH 03/23] support pre_train model --- basics/base_task.py | 9 +++--- configs/base.yaml | 9 ++++++ scripts/train.py | 51 +++--------------------------- utils/__init__.py | 69 +++++++++++++++++++++++++++++++++++++++++ utils/training_utils.py | 15 --------- 5 files changed, 88 insertions(+), 65 deletions(-) diff --git a/basics/base_task.py b/basics/base_task.py index afbeebb5..ba2b5161 100644 --- a/basics/base_task.py +++ b/basics/base_task.py @@ -288,12 +288,13 @@ def on_test_end(self): ########### @classmethod - def start(cls,pre_train=None): + def start(cls): pl.seed_everything(hparams['seed'], workers=True) task = cls() - if pre_train is not None: - task.load_state_dict(pre_train,strict=False) - print("load success-------------------------------------------------------------------") + utils.load_warp(task) + # if pre_train is not None: + # task.load_state_dict(pre_train,strict=False) + # print("load success-------------------------------------------------------------------") work_dir = pathlib.Path(hparams['work_dir']) trainer = pl.Trainer( diff --git a/configs/base.yaml b/configs/base.yaml index 7d9f3732..efb7333d 100644 --- a/configs/base.yaml +++ b/configs/base.yaml @@ -89,3 +89,12 @@ pl_trainer_precision: '32-true' pl_trainer_num_nodes: 1 pl_trainer_strategy: 'auto' ddp_backend: 'nccl' # choose from 'gloo', 'nccl', 'nccl_no_p2p' + +########### +# finetune +########### + +finetune_enable: false +finetune_ckpt_path: +finetune_params_blacklist: ['model.fs2.encoder.embed_tokens.weight','model.fs2.speed_embed','model.fs2.key_shift_embed','model.fs2.txt_embed.weight'] +finetune_adapt_shapes: false diff --git a/scripts/train.py b/scripts/train.py index 10b20f01..a45549a1 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -1,9 +1,9 @@ import importlib import os -import pathlib +# import pathlib import sys from pathlib import Path -import torch +# import torch @@ -12,55 +12,14 @@ sys.path.insert(0, str(root_dir)) os.environ['TORCH_CUDNN_V8_API_ENABLED'] = '1' # Prevent unacceptable slowdowns when using 16 precision -from utils.training_utils import get_latest_checkpoint_path,remove_map +# from utils.training_utils import get_latest_checkpoint_path from utils.hparams import set_hparams, hparams set_hparams() if hparams['ddp_backend'] == 'nccl_no_p2p': print("Disabling NCCL P2P") os.environ['NCCL_P2P_DISABLE'] = '1' -def load_pre_train_model(type='wavenet'): - r""" load pre model - Args: - model : - type (str): wavenet mean loading wavenet full mean loading fs2 and wavenet - - """ - pre_train_ckpt_path=hparams.get('pre_train_path') - if pre_train_ckpt_path is not None: - ckpt=torch.load(pre_train_ckpt_path) - if ckpt.get('category') is None: - raise RuntimeError("") - if isinstance(type, str): - if type == "wavenet": - state_dict = {} - for i in ckpt['state_dict']: - if i in remove_map['base']: - continue - if 'diffusion' in i: - - state_dict[i] = ckpt['state_dict'][i] - return state_dict - # model.load_state_dict(state_dict=state_dict,strict=False) - elif type == "full": - ... - elif isinstance(type, list): - ... - else: - raise RuntimeError("") - else: - return None - - -def load_warp(): - if get_latest_checkpoint_path(pathlib.Path(hparams['work_dir'])) is not None: - pass - return None - type=hparams.get('pre_train_load_type') - if type is None: - type="wavenet" - return load_pre_train_model(type=type) @@ -71,9 +30,9 @@ def run_task(): pkg = ".".join(hparams["task_cls"].split(".")[:-1]) cls_name = hparams["task_cls"].split(".")[-1] task_cls = getattr(importlib.import_module(pkg), cls_name) - ckpt =load_warp() - task_cls.start(ckpt) + + task_cls.start() if __name__ == '__main__': diff --git a/utils/__init__.py b/utils/__init__.py index 6ef9d7a6..a8bc1974 100644 --- a/utils/__init__.py +++ b/utils/__init__.py @@ -11,6 +11,8 @@ import torch.nn.functional as F from basics.base_module import CategorizedModule +from utils.hparams import hparams +from utils.training_utils import get_latest_checkpoint_path def tensors_to_scalars(metrics): @@ -209,6 +211,73 @@ def load_ckpt( shown_model_name = f'\'{key_in_ckpt}\'' print(f'| load {shown_model_name} from \'{checkpoint_path}\'.') +def load_finetune_ckpt( + model,state_dict +): + adapt_shapes=hparams['finetune_adapt_shapes'] + if adapt_shapes: + cur_model_state_dict = model.state_dict() + unmatched_keys = [] + for key, param in state_dict.items(): + if key in cur_model_state_dict: + new_param = cur_model_state_dict[key] + if new_param.shape != param.shape: + unmatched_keys.append(key) + print('| Unmatched keys: ', key, new_param.shape, param.shape) + for key in unmatched_keys: + del state_dict[key] + model.load_state_dict(state_dict, strict=False) + + + +def load_pre_train_model(): + + pre_train_ckpt_path=hparams.get('finetune_ckpt_path') + blacklist=hparams.get('finetune_params_blacklist') + # whitelist=hparams.get('pre_train_whitelist') + if blacklist is None: + blacklist=[] + # if whitelist is None: + # raise RuntimeError("") + + if pre_train_ckpt_path is not None: + ckpt=torch.load(pre_train_ckpt_path) + if ckpt.get('category') is None: + raise RuntimeError("") + state_dict={} + for i in ckpt['state_dict']: + # if 'diffusion' in i: + # if i in rrrr: + # continue + skip = 0 + for b in blacklist: + if b in i: + skip = 1 + continue + + if skip == 1: + continue + + + state_dict[i] = ckpt['state_dict'][i] + print(i) + return state_dict + else: + raise RuntimeError("") + + +def load_warp(modle): + if not hparams['finetune_enable']: + return None + if get_latest_checkpoint_path(pathlib.Path(hparams['work_dir'])) is not None: #check + pass + return None + + load_finetune_ckpt(modle,load_pre_train_model()) + + # return load_pre_train_model() + + def remove_padding(x, padding_idx=0): if x is None: diff --git a/utils/training_utils.py b/utils/training_utils.py index 21f4e81e..1563daca 100644 --- a/utils/training_utils.py +++ b/utils/training_utils.py @@ -16,21 +16,6 @@ import utils from utils.hparams import hparams -remove_map = {'base': ['model.diffusion.alphas_cumprod', 'model.diffusion.alphas_cumprod_prev', - 'model.diffusion.sqrt_alphas_cumprod', - 'model.diffusion.sqrt_one_minus_alphas_cumprod', 'model.diffusion.log_one_minus_alphas_cumprod', - 'model.diffusion.sqrt_recip_alphas_cumprod', 'model.diffusion.sqrt_recipm1_alphas_cumprod', - 'model.diffusion.posterior_variance', - 'model.diffusion.posterior_log_variance_clipped', 'model.diffusion.posterior_mean_coef1', - 'model.diffusion.posterior_mean_coef2', - 'model.diffusion.spec_min', 'model.diffusion.spec_max'], - 'speed_emb': ['model.fs2.speed_embed.weight', 'model.fs2.speed_embed.bias', ], - 'key_shift_emb': ['model.fs2.key_shift_embed.bias', 'model.fs2.key_shift_embed.weight', ] - , 'pitch_emb': ['model.fs2.pitch_embed.bias', 'model.fs2.pitch_embed.weight', ], - 'token_emb': ['model.fs2.encoder.embed_tokens.weight', 'model.fs2.txt_embed.weight'], - 'dur_emb': ['model.fs2.dur_embed.bias', 'model.fs2.dur_embed.weight']} - - # ==========LR schedulers========== From 6bcf1be8422a1ecc641e015395716a8801b3f785 Mon Sep 17 00:00:00 2001 From: autumn <2> Date: Fri, 7 Jul 2023 21:07:24 +0800 Subject: [PATCH 04/23] support pre_train model --- basics/base_task.py | 3 ++- configs/base.yaml | 4 ++-- utils/__init__.py | 20 ++++++++++++-------- 3 files changed, 16 insertions(+), 11 deletions(-) diff --git a/basics/base_task.py b/basics/base_task.py index ba2b5161..b834e1cd 100644 --- a/basics/base_task.py +++ b/basics/base_task.py @@ -87,6 +87,7 @@ def __init__(self, *args, **kwargs): def setup(self, stage): self.phone_encoder = self.build_phone_encoder() self.model = self.build_model() + utils.load_warp(self) self.print_arch() self.build_losses() self.train_dataset = self.dataset_cls(hparams['train_set_name']) @@ -291,7 +292,7 @@ def on_test_end(self): def start(cls): pl.seed_everything(hparams['seed'], workers=True) task = cls() - utils.load_warp(task) + # if pre_train is not None: # task.load_state_dict(pre_train,strict=False) # print("load success-------------------------------------------------------------------") diff --git a/configs/base.yaml b/configs/base.yaml index efb7333d..ed1130ff 100644 --- a/configs/base.yaml +++ b/configs/base.yaml @@ -96,5 +96,5 @@ ddp_backend: 'nccl' # choose from 'gloo', 'nccl', 'nccl_no_p2p' finetune_enable: false finetune_ckpt_path: -finetune_params_blacklist: ['model.fs2.encoder.embed_tokens.weight','model.fs2.speed_embed','model.fs2.key_shift_embed','model.fs2.txt_embed.weight'] -finetune_adapt_shapes: false +finetune_ignored_params: ['model.fs2.encoder.embed_tokens.weight','model.fs2.txt_embed.weight'] +finetune_strict_shapes: true diff --git a/utils/__init__.py b/utils/__init__.py index a8bc1974..788e6b0f 100644 --- a/utils/__init__.py +++ b/utils/__init__.py @@ -214,8 +214,8 @@ def load_ckpt( def load_finetune_ckpt( model,state_dict ): - adapt_shapes=hparams['finetune_adapt_shapes'] - if adapt_shapes: + adapt_shapes=hparams['finetune_strict_shapes'] + if not adapt_shapes: cur_model_state_dict = model.state_dict() unmatched_keys = [] for key, param in state_dict.items(): @@ -230,10 +230,10 @@ def load_finetune_ckpt( -def load_pre_train_model(): +def load_pre_train_model(model): pre_train_ckpt_path=hparams.get('finetune_ckpt_path') - blacklist=hparams.get('finetune_params_blacklist') + blacklist=hparams.get('finetune_ignored_params') # whitelist=hparams.get('pre_train_whitelist') if blacklist is None: blacklist=[] @@ -242,8 +242,12 @@ def load_pre_train_model(): if pre_train_ckpt_path is not None: ckpt=torch.load(pre_train_ckpt_path) - if ckpt.get('category') is None: - raise RuntimeError("") + # if ckpt.get('category') is None: + # raise RuntimeError("") + + if isinstance(model.model, CategorizedModule): + model.model.check_category(ckpt.get('category')) + state_dict={} for i in ckpt['state_dict']: # if 'diffusion' in i: @@ -266,14 +270,14 @@ def load_pre_train_model(): raise RuntimeError("") -def load_warp(modle): +def load_warp(model): if not hparams['finetune_enable']: return None if get_latest_checkpoint_path(pathlib.Path(hparams['work_dir'])) is not None: #check pass return None - load_finetune_ckpt(modle,load_pre_train_model()) + load_finetune_ckpt(model,load_pre_train_model(model)) # return load_pre_train_model() From 08526369e8df0eac8daa87b85bb69f7015027578 Mon Sep 17 00:00:00 2001 From: autumn <2> Date: Fri, 7 Jul 2023 21:10:13 +0800 Subject: [PATCH 05/23] support pre_train model --- basics/base_task.py | 14 ++++++++------ scripts/train.py | 7 +------ utils/__init__.py | 33 ++++++++++++++++----------------- 3 files changed, 25 insertions(+), 29 deletions(-) diff --git a/basics/base_task.py b/basics/base_task.py index b834e1cd..5460d30d 100644 --- a/basics/base_task.py +++ b/basics/base_task.py @@ -384,16 +384,16 @@ def on_load_checkpoint(self, checkpoint): from utils import simulate_lr_scheduler if checkpoint.get('trainer_stage', '') == RunningStage.VALIDATING.value: self.skip_immediate_validation = True - + optimizer_args = hparams['optimizer_args'] scheduler_args = hparams['lr_scheduler_args'] - + if 'beta1' in optimizer_args and 'beta2' in optimizer_args and 'betas' not in optimizer_args: optimizer_args['betas'] = (optimizer_args['beta1'], optimizer_args['beta2']) if checkpoint.get('optimizer_states', None): opt_states = checkpoint['optimizer_states'] - assert len(opt_states) == 1 # only support one optimizer + assert len(opt_states) == 1 # only support one optimizer opt_state = opt_states[0] for param_group in opt_state['param_groups']: for k, v in optimizer_args.items(): @@ -403,13 +403,14 @@ def on_load_checkpoint(self, checkpoint): rank_zero_info(f'| Overriding optimizer parameter {k} from checkpoint: {param_group[k]} -> {v}') param_group[k] = v if 'initial_lr' in param_group and param_group['initial_lr'] != optimizer_args['lr']: - rank_zero_info(f'| Overriding optimizer parameter initial_lr from checkpoint: {param_group["initial_lr"]} -> {optimizer_args["lr"]}') + rank_zero_info( + f'| Overriding optimizer parameter initial_lr from checkpoint: {param_group["initial_lr"]} -> {optimizer_args["lr"]}') param_group['initial_lr'] = optimizer_args['lr'] if checkpoint.get('lr_schedulers', None): assert checkpoint.get('optimizer_states', False) schedulers = checkpoint['lr_schedulers'] - assert len(schedulers) == 1 # only support one scheduler + assert len(schedulers) == 1 # only support one scheduler scheduler = schedulers[0] for k, v in scheduler_args.items(): if k in scheduler and scheduler[k] != v: @@ -424,5 +425,6 @@ def on_load_checkpoint(self, checkpoint): scheduler['_last_lr'] = new_lrs for param_group, new_lr in zip(checkpoint['optimizer_states'][0]['param_groups'], new_lrs): if param_group['lr'] != new_lr: - rank_zero_info(f'| Overriding optimizer parameter lr from checkpoint: {param_group["lr"]} -> {new_lr}') + rank_zero_info( + f'| Overriding optimizer parameter lr from checkpoint: {param_group["lr"]} -> {new_lr}') param_group['lr'] = new_lr diff --git a/scripts/train.py b/scripts/train.py index a45549a1..50e36a81 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -3,8 +3,8 @@ # import pathlib import sys from pathlib import Path -# import torch +# import torch root_dir = Path(__file__).parent.parent.resolve() @@ -21,17 +21,12 @@ os.environ['NCCL_P2P_DISABLE'] = '1' - - - - def run_task(): assert hparams['task_cls'] != '' pkg = ".".join(hparams["task_cls"].split(".")[:-1]) cls_name = hparams["task_cls"].split(".")[-1] task_cls = getattr(importlib.import_module(pkg), cls_name) - task_cls.start() diff --git a/utils/__init__.py b/utils/__init__.py index 788e6b0f..2a8f673f 100644 --- a/utils/__init__.py +++ b/utils/__init__.py @@ -11,7 +11,7 @@ import torch.nn.functional as F from basics.base_module import CategorizedModule -from utils.hparams import hparams +from utils.hparams import hparams from utils.training_utils import get_latest_checkpoint_path @@ -151,7 +151,8 @@ def filter_kwargs(dict_to_filter, kwarg_obj): sig = inspect.signature(kwarg_obj) filter_keys = [param.name for param in sig.parameters.values() if param.kind == param.POSITIONAL_OR_KEYWORD] - filtered_dict = {filter_key: dict_to_filter[filter_key] for filter_key in filter_keys if filter_key in dict_to_filter} + filtered_dict = {filter_key: dict_to_filter[filter_key] for filter_key in filter_keys if + filter_key in dict_to_filter} return filtered_dict @@ -211,10 +212,11 @@ def load_ckpt( shown_model_name = f'\'{key_in_ckpt}\'' print(f'| load {shown_model_name} from \'{checkpoint_path}\'.') + def load_finetune_ckpt( - model,state_dict + model, state_dict ): - adapt_shapes=hparams['finetune_strict_shapes'] + adapt_shapes = hparams['finetune_strict_shapes'] if not adapt_shapes: cur_model_state_dict = model.state_dict() unmatched_keys = [] @@ -229,26 +231,24 @@ def load_finetune_ckpt( model.load_state_dict(state_dict, strict=False) - def load_pre_train_model(model): - - pre_train_ckpt_path=hparams.get('finetune_ckpt_path') - blacklist=hparams.get('finetune_ignored_params') + pre_train_ckpt_path = hparams.get('finetune_ckpt_path') + blacklist = hparams.get('finetune_ignored_params') # whitelist=hparams.get('pre_train_whitelist') - if blacklist is None: - blacklist=[] + if blacklist is None: + blacklist = [] # if whitelist is None: # raise RuntimeError("") if pre_train_ckpt_path is not None: - ckpt=torch.load(pre_train_ckpt_path) + ckpt = torch.load(pre_train_ckpt_path) # if ckpt.get('category') is None: # raise RuntimeError("") if isinstance(model.model, CategorizedModule): model.model.check_category(ckpt.get('category')) - state_dict={} + state_dict = {} for i in ckpt['state_dict']: # if 'diffusion' in i: # if i in rrrr: @@ -262,7 +262,6 @@ def load_pre_train_model(model): if skip == 1: continue - state_dict[i] = ckpt['state_dict'][i] print(i) return state_dict @@ -273,16 +272,15 @@ def load_pre_train_model(model): def load_warp(model): if not hparams['finetune_enable']: return None - if get_latest_checkpoint_path(pathlib.Path(hparams['work_dir'])) is not None: #check + if get_latest_checkpoint_path(pathlib.Path(hparams['work_dir'])) is not None: # check pass return None - load_finetune_ckpt(model,load_pre_train_model(model)) + load_finetune_ckpt(model, load_pre_train_model(model)) # return load_pre_train_model() - def remove_padding(x, padding_idx=0): if x is None: return None @@ -340,7 +338,8 @@ def simulate_lr_scheduler(optimizer_args, scheduler_args, last_epoch=-1, num_par [{'params': torch.nn.Parameter(), 'initial_lr': optimizer_args['lr']} for _ in range(num_param_groups)], **optimizer_args ) - scheduler = build_object_from_config(scheduler_args['scheduler_cls'], optimizer, last_epoch=last_epoch, **scheduler_args) + scheduler = build_object_from_config(scheduler_args['scheduler_cls'], optimizer, last_epoch=last_epoch, + **scheduler_args) if hasattr(scheduler, '_get_closed_form_lr'): return scheduler._get_closed_form_lr() From edc5599e7e1e0b04876355b7884886053de81c67 Mon Sep 17 00:00:00 2001 From: autumn <2> Date: Fri, 7 Jul 2023 21:24:48 +0800 Subject: [PATCH 06/23] support pre_train model --- basics/base_task.py | 58 ++++++++++++++++++++++++++++++++++++++++++ utils/__init__.py | 62 +-------------------------------------------- 2 files changed, 59 insertions(+), 61 deletions(-) diff --git a/basics/base_task.py b/basics/base_task.py index 5460d30d..f27fd0a6 100644 --- a/basics/base_task.py +++ b/basics/base_task.py @@ -88,11 +88,69 @@ def setup(self, stage): self.phone_encoder = self.build_phone_encoder() self.model = self.build_model() utils.load_warp(self) + if hparams['finetune_enable'] and get_latest_checkpoint_path(pathlib.Path(hparams['work_dir'])) is None: + self.load_finetune_ckpt( self.load_pre_train_model()) self.print_arch() self.build_losses() self.train_dataset = self.dataset_cls(hparams['train_set_name']) self.valid_dataset = self.dataset_cls(hparams['valid_set_name']) + def load_finetune_ckpt( + self, state_dict + ): + model=self + adapt_shapes = hparams['finetune_strict_shapes'] + if not adapt_shapes: + cur_model_state_dict = model.state_dict() + unmatched_keys = [] + for key, param in state_dict.items(): + if key in cur_model_state_dict: + new_param = cur_model_state_dict[key] + if new_param.shape != param.shape: + unmatched_keys.append(key) + print('| Unmatched keys: ', key, new_param.shape, param.shape) + for key in unmatched_keys: + del state_dict[key] + model.load_state_dict(state_dict, strict=False) + + def load_pre_train_model(self): + model=self + pre_train_ckpt_path = hparams.get('finetune_ckpt_path') + blacklist = hparams.get('finetune_ignored_params') + # whitelist=hparams.get('pre_train_whitelist') + if blacklist is None: + blacklist = [] + # if whitelist is None: + # raise RuntimeError("") + + if pre_train_ckpt_path is not None: + ckpt = torch.load(pre_train_ckpt_path) + # if ckpt.get('category') is None: + # raise RuntimeError("") + + if isinstance(model.model, CategorizedModule): + model.model.check_category(ckpt.get('category')) + + state_dict = {} + for i in ckpt['state_dict']: + # if 'diffusion' in i: + # if i in rrrr: + # continue + skip = 0 + for b in blacklist: + if b in i: + skip = 1 + continue + + if skip == 1: + continue + + state_dict[i] = ckpt['state_dict'][i] + print(i) + return state_dict + else: + raise RuntimeError("") + @staticmethod def build_phone_encoder(): phone_list = build_phoneme_list() diff --git a/utils/__init__.py b/utils/__init__.py index 2a8f673f..407ccb44 100644 --- a/utils/__init__.py +++ b/utils/__init__.py @@ -213,70 +213,10 @@ def load_ckpt( print(f'| load {shown_model_name} from \'{checkpoint_path}\'.') -def load_finetune_ckpt( - model, state_dict -): - adapt_shapes = hparams['finetune_strict_shapes'] - if not adapt_shapes: - cur_model_state_dict = model.state_dict() - unmatched_keys = [] - for key, param in state_dict.items(): - if key in cur_model_state_dict: - new_param = cur_model_state_dict[key] - if new_param.shape != param.shape: - unmatched_keys.append(key) - print('| Unmatched keys: ', key, new_param.shape, param.shape) - for key in unmatched_keys: - del state_dict[key] - model.load_state_dict(state_dict, strict=False) - - -def load_pre_train_model(model): - pre_train_ckpt_path = hparams.get('finetune_ckpt_path') - blacklist = hparams.get('finetune_ignored_params') - # whitelist=hparams.get('pre_train_whitelist') - if blacklist is None: - blacklist = [] - # if whitelist is None: - # raise RuntimeError("") - - if pre_train_ckpt_path is not None: - ckpt = torch.load(pre_train_ckpt_path) - # if ckpt.get('category') is None: - # raise RuntimeError("") - - if isinstance(model.model, CategorizedModule): - model.model.check_category(ckpt.get('category')) - - state_dict = {} - for i in ckpt['state_dict']: - # if 'diffusion' in i: - # if i in rrrr: - # continue - skip = 0 - for b in blacklist: - if b in i: - skip = 1 - continue - - if skip == 1: - continue - - state_dict[i] = ckpt['state_dict'][i] - print(i) - return state_dict - else: - raise RuntimeError("") -def load_warp(model): - if not hparams['finetune_enable']: - return None - if get_latest_checkpoint_path(pathlib.Path(hparams['work_dir'])) is not None: # check - pass - return None - load_finetune_ckpt(model, load_pre_train_model(model)) + # return load_pre_train_model() From 0f65cbe6b40fa141d081590f1dee493295962e9e Mon Sep 17 00:00:00 2001 From: autumn <2> Date: Fri, 7 Jul 2023 21:27:04 +0800 Subject: [PATCH 07/23] support pre_train model --- basics/base_task.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/basics/base_task.py b/basics/base_task.py index f27fd0a6..6d2680b6 100644 --- a/basics/base_task.py +++ b/basics/base_task.py @@ -87,7 +87,7 @@ def __init__(self, *args, **kwargs): def setup(self, stage): self.phone_encoder = self.build_phone_encoder() self.model = self.build_model() - utils.load_warp(self) + # utils.load_warp(self) if hparams['finetune_enable'] and get_latest_checkpoint_path(pathlib.Path(hparams['work_dir'])) is None: self.load_finetune_ckpt( self.load_pre_train_model()) self.print_arch() @@ -98,10 +98,10 @@ def setup(self, stage): def load_finetune_ckpt( self, state_dict ): - model=self + adapt_shapes = hparams['finetune_strict_shapes'] if not adapt_shapes: - cur_model_state_dict = model.state_dict() + cur_model_state_dict = self.state_dict() unmatched_keys = [] for key, param in state_dict.items(): if key in cur_model_state_dict: @@ -111,10 +111,10 @@ def load_finetune_ckpt( print('| Unmatched keys: ', key, new_param.shape, param.shape) for key in unmatched_keys: del state_dict[key] - model.load_state_dict(state_dict, strict=False) + self.load_state_dict(state_dict, strict=False) def load_pre_train_model(self): - model=self + pre_train_ckpt_path = hparams.get('finetune_ckpt_path') blacklist = hparams.get('finetune_ignored_params') # whitelist=hparams.get('pre_train_whitelist') @@ -128,8 +128,8 @@ def load_pre_train_model(self): # if ckpt.get('category') is None: # raise RuntimeError("") - if isinstance(model.model, CategorizedModule): - model.model.check_category(ckpt.get('category')) + if isinstance(self.model, CategorizedModule): + self.model.check_category(ckpt.get('category')) state_dict = {} for i in ckpt['state_dict']: From 2a701bc08af6f25263bb0451fbd3cc03826874aa Mon Sep 17 00:00:00 2001 From: autumn <2> Date: Fri, 7 Jul 2023 21:35:06 +0800 Subject: [PATCH 08/23] support pre_train model --- configs/base.yaml | 7 +++++-- scripts/train.py | 7 ++----- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/configs/base.yaml b/configs/base.yaml index ed1130ff..5cb67a3d 100644 --- a/configs/base.yaml +++ b/configs/base.yaml @@ -95,6 +95,9 @@ ddp_backend: 'nccl' # choose from 'gloo', 'nccl', 'nccl_no_p2p' ########### finetune_enable: false -finetune_ckpt_path: -finetune_ignored_params: ['model.fs2.encoder.embed_tokens.weight','model.fs2.txt_embed.weight'] +finetune_ckpt_path: null +finetune_ignored_params: + - model.fs2.encoder.embed_tokens.weight + - model.fs2.txt_embed.weight + finetune_strict_shapes: true diff --git a/scripts/train.py b/scripts/train.py index 50e36a81..1df7b6bc 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -1,18 +1,15 @@ import importlib import os -# import pathlib + import sys from pathlib import Path -# import torch - - root_dir = Path(__file__).parent.parent.resolve() os.environ['PYTHONPATH'] = str(root_dir) sys.path.insert(0, str(root_dir)) os.environ['TORCH_CUDNN_V8_API_ENABLED'] = '1' # Prevent unacceptable slowdowns when using 16 precision -# from utils.training_utils import get_latest_checkpoint_path + from utils.hparams import set_hparams, hparams set_hparams() From 6a3f899445205fa9821b4aa0eee96adf4ffb1dbc Mon Sep 17 00:00:00 2001 From: autumn <2> Date: Fri, 7 Jul 2023 21:38:38 +0800 Subject: [PATCH 09/23] support pre_train model --- basics/base_task.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/basics/base_task.py b/basics/base_task.py index 6d2680b6..9ab14e58 100644 --- a/basics/base_task.py +++ b/basics/base_task.py @@ -140,7 +140,7 @@ def load_pre_train_model(self): for b in blacklist: if b in i: skip = 1 - continue + break if skip == 1: continue From 8dbc87aaac4b028a11750ae3c9390ceea387799c Mon Sep 17 00:00:00 2001 From: autumn <2> Date: Fri, 7 Jul 2023 21:42:04 +0800 Subject: [PATCH 10/23] support pre_train model --- basics/base_task.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/basics/base_task.py b/basics/base_task.py index 9ab14e58..72cb0d48 100644 --- a/basics/base_task.py +++ b/basics/base_task.py @@ -145,9 +145,9 @@ def load_pre_train_model(self): if skip == 1: continue - state_dict[i] = ckpt['state_dict'][i] + del ckpt['state_dict'][i] print(i) - return state_dict + return state_dict['state_dict'] else: raise RuntimeError("") From 0fcb0dd59f93b5d25c033c193341e0d9043c6447 Mon Sep 17 00:00:00 2001 From: autumn <2> Date: Fri, 7 Jul 2023 21:44:33 +0800 Subject: [PATCH 11/23] support pre_train model --- basics/base_task.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/basics/base_task.py b/basics/base_task.py index 72cb0d48..9ab14e58 100644 --- a/basics/base_task.py +++ b/basics/base_task.py @@ -145,9 +145,9 @@ def load_pre_train_model(self): if skip == 1: continue - del ckpt['state_dict'][i] + state_dict[i] = ckpt['state_dict'][i] print(i) - return state_dict['state_dict'] + return state_dict else: raise RuntimeError("") From faac4cea2952ba455d366fe38a8ce9bdc0aa721a Mon Sep 17 00:00:00 2001 From: autumn <2> Date: Fri, 7 Jul 2023 21:46:39 +0800 Subject: [PATCH 12/23] support pre_train model --- basics/base_task.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/basics/base_task.py b/basics/base_task.py index 9ab14e58..9f9db9fb 100644 --- a/basics/base_task.py +++ b/basics/base_task.py @@ -138,7 +138,7 @@ def load_pre_train_model(self): # continue skip = 0 for b in blacklist: - if b in i: + if i.startswith(b): skip = 1 break From 6d7e5a89da9d5ce0b85a2bcabdfada574e6fbcca Mon Sep 17 00:00:00 2001 From: autumn <2> Date: Fri, 7 Jul 2023 21:52:36 +0800 Subject: [PATCH 13/23] support pre_train model --- basics/base_task.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/basics/base_task.py b/basics/base_task.py index 9f9db9fb..433abbff 100644 --- a/basics/base_task.py +++ b/basics/base_task.py @@ -136,13 +136,13 @@ def load_pre_train_model(self): # if 'diffusion' in i: # if i in rrrr: # continue - skip = 0 + skip = False for b in blacklist: if i.startswith(b): - skip = 1 + skip = True break - if skip == 1: + if skip: continue state_dict[i] = ckpt['state_dict'][i] From 8fa61724c8041ef9f114a4227c38a7bf72563f9b Mon Sep 17 00:00:00 2001 From: autumn <2> Date: Sat, 8 Jul 2023 10:38:04 +0800 Subject: [PATCH 14/23] support pre_train model add doc --- docs/ConfigurationSchemas.md | 82 ++++++++++++++++++++++++++++++++++++ 1 file changed, 82 insertions(+) diff --git a/docs/ConfigurationSchemas.md b/docs/ConfigurationSchemas.md index f4c2d195..c90c4b83 100644 --- a/docs/ConfigurationSchemas.md +++ b/docs/ConfigurationSchemas.md @@ -3322,3 +3322,85 @@ int 2048 + +### finetune_enable + +use pretrain model + +#### visibility + +acoustic + +#### type + +boolean + +#### default + +false + +#### constraints + +Must be true if use pretrain model + +### finetune_ckpt_path + +your pretrain path + +#### visibility + +acoustic + +#### type + +str + +#### default + +null + +#### constraints + +Must be a path + +### finetune_ignored_params + +the params you want to ignore in finetune + +#### visibility + +acoustic + +#### type + +list + +#### default + + - model.fs2.encoder.embed_tokens.weight + - model.fs2.txt_embed.weight + +#### constraints + +Must be a list + +### finetune_strict_shapes + +when you finetune model have some shapes mismatch the model ignored or error +default is error + +#### visibility + +acoustic + +#### type + +boolean + +#### default + + true + +#### constraints + +Must be a boolean \ No newline at end of file From 1656ca5ab33080f84582f883050ec57537db917e Mon Sep 17 00:00:00 2001 From: autumn <2> Date: Sat, 8 Jul 2023 12:34:56 +0800 Subject: [PATCH 15/23] support pre_train model add doc --- docs/ConfigurationSchemas.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/ConfigurationSchemas.md b/docs/ConfigurationSchemas.md index c90c4b83..05ca36b9 100644 --- a/docs/ConfigurationSchemas.md +++ b/docs/ConfigurationSchemas.md @@ -3329,7 +3329,7 @@ use pretrain model #### visibility -acoustic +all #### type @@ -3349,7 +3349,7 @@ your pretrain path #### visibility -acoustic +all #### type @@ -3369,7 +3369,7 @@ the params you want to ignore in finetune #### visibility -acoustic +all #### type @@ -3391,7 +3391,7 @@ default is error #### visibility -acoustic +all #### type From 78cd2870df7982a6a59b06bb1d1b92a30581629c Mon Sep 17 00:00:00 2001 From: autumn <2> Date: Sat, 8 Jul 2023 14:17:50 +0800 Subject: [PATCH 16/23] support pre_train model add doc --- configs/acoustic.yaml | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/configs/acoustic.yaml b/configs/acoustic.yaml index a08022d5..5fc0916e 100644 --- a/configs/acoustic.yaml +++ b/configs/acoustic.yaml @@ -92,3 +92,9 @@ max_updates: 320000 num_ckpt_keep: 5 permanent_ckpt_start: 120000 permanent_ckpt_interval: 40000 + + +finetune_enable: true +finetune_ckpt_path: null +finetune_ignored_params: ['model.fs2.encoder.embed_tokens.weight','model.fs2.speed_embed','model.fs2.key_shift_embed'] +finetune_strict_shapes: true From ea9157ab230c1ba74ce764d5ed90f5452633d237 Mon Sep 17 00:00:00 2001 From: autumn <2> Date: Sat, 8 Jul 2023 14:20:47 +0800 Subject: [PATCH 17/23] support pre_train model add doc --- configs/acoustic.yaml | 2 +- configs/base.yaml | 5 ++--- configs/variance.yaml | 6 ++++++ 3 files changed, 9 insertions(+), 4 deletions(-) diff --git a/configs/acoustic.yaml b/configs/acoustic.yaml index 5fc0916e..bb3db578 100644 --- a/configs/acoustic.yaml +++ b/configs/acoustic.yaml @@ -96,5 +96,5 @@ permanent_ckpt_interval: 40000 finetune_enable: true finetune_ckpt_path: null -finetune_ignored_params: ['model.fs2.encoder.embed_tokens.weight','model.fs2.speed_embed','model.fs2.key_shift_embed'] +finetune_ignored_params: ['model.fs2.encoder.embed_tokens.weight','model.fs2.txt_embed.weight','model.fs2.spk_embed'] finetune_strict_shapes: true diff --git a/configs/base.yaml b/configs/base.yaml index 5cb67a3d..7b4f2990 100644 --- a/configs/base.yaml +++ b/configs/base.yaml @@ -96,8 +96,7 @@ ddp_backend: 'nccl' # choose from 'gloo', 'nccl', 'nccl_no_p2p' finetune_enable: false finetune_ckpt_path: null -finetune_ignored_params: - - model.fs2.encoder.embed_tokens.weight - - model.fs2.txt_embed.weight +finetune_ignored_params: [] + finetune_strict_shapes: true diff --git a/configs/variance.yaml b/configs/variance.yaml index edf947e2..683feeb4 100644 --- a/configs/variance.yaml +++ b/configs/variance.yaml @@ -104,3 +104,9 @@ max_updates: 320000 num_ckpt_keep: 5 permanent_ckpt_start: 120000 permanent_ckpt_interval: 40000 + + +finetune_enable: true +finetune_ckpt_path: null +finetune_ignored_params: ['model.fs2.encoder.embed_tokens.weight','model.spk_embed',] +finetune_strict_shapes: true From 088be5a66045ad2f75be19ff19505241b3975f33 Mon Sep 17 00:00:00 2001 From: autumn <2> Date: Sat, 8 Jul 2023 14:21:08 +0800 Subject: [PATCH 18/23] support pre_train model add doc --- docs/ConfigurationSchemas.md | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/docs/ConfigurationSchemas.md b/docs/ConfigurationSchemas.md index 05ca36b9..75c8ee0e 100644 --- a/docs/ConfigurationSchemas.md +++ b/docs/ConfigurationSchemas.md @@ -3377,8 +3377,7 @@ list #### default - - model.fs2.encoder.embed_tokens.weight - - model.fs2.txt_embed.weight +null #### constraints From d2162e5eaa52634b53aa9b6e9618f55e5303351e Mon Sep 17 00:00:00 2001 From: autumn <2> Date: Sat, 8 Jul 2023 14:27:16 +0800 Subject: [PATCH 19/23] support pre_train model add doc --- configs/acoustic.yaml | 2 +- configs/variance.yaml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/configs/acoustic.yaml b/configs/acoustic.yaml index bb3db578..9754219a 100644 --- a/configs/acoustic.yaml +++ b/configs/acoustic.yaml @@ -94,7 +94,7 @@ permanent_ckpt_start: 120000 permanent_ckpt_interval: 40000 -finetune_enable: true +finetune_enable: false finetune_ckpt_path: null finetune_ignored_params: ['model.fs2.encoder.embed_tokens.weight','model.fs2.txt_embed.weight','model.fs2.spk_embed'] finetune_strict_shapes: true diff --git a/configs/variance.yaml b/configs/variance.yaml index 683feeb4..f6e73681 100644 --- a/configs/variance.yaml +++ b/configs/variance.yaml @@ -106,7 +106,7 @@ permanent_ckpt_start: 120000 permanent_ckpt_interval: 40000 -finetune_enable: true +finetune_enable: false finetune_ckpt_path: null finetune_ignored_params: ['model.fs2.encoder.embed_tokens.weight','model.spk_embed',] finetune_strict_shapes: true From 91a7b885e6499b6d4abb6dc7ac80eca11b94fc83 Mon Sep 17 00:00:00 2001 From: autumn <2> Date: Sat, 8 Jul 2023 14:32:23 +0800 Subject: [PATCH 20/23] support pre_train model add doc --- configs/variance.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/configs/variance.yaml b/configs/variance.yaml index f6e73681..7ca5d08b 100644 --- a/configs/variance.yaml +++ b/configs/variance.yaml @@ -108,5 +108,5 @@ permanent_ckpt_interval: 40000 finetune_enable: false finetune_ckpt_path: null -finetune_ignored_params: ['model.fs2.encoder.embed_tokens.weight','model.spk_embed',] +finetune_ignored_params: ['model.spk_embed','model.fs2.txt_embed.weight','model.fs2.encoder.embed_tokens.weight'] finetune_strict_shapes: true From bccd9f4a5b59e67393538b869bce1f5133c07188 Mon Sep 17 00:00:00 2001 From: autumn <2> Date: Sat, 8 Jul 2023 14:36:39 +0800 Subject: [PATCH 21/23] support pre_train model add doc --- configs/acoustic.yaml | 6 +++++- configs/variance.yaml | 5 ++++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/configs/acoustic.yaml b/configs/acoustic.yaml index 9754219a..ed17dc40 100644 --- a/configs/acoustic.yaml +++ b/configs/acoustic.yaml @@ -96,5 +96,9 @@ permanent_ckpt_interval: 40000 finetune_enable: false finetune_ckpt_path: null -finetune_ignored_params: ['model.fs2.encoder.embed_tokens.weight','model.fs2.txt_embed.weight','model.fs2.spk_embed'] + +finetune_ignored_params: + - model.fs2.encoder.embed_tokens + - model.fs2.txt_embed + - model.fs2.spk_embed finetune_strict_shapes: true diff --git a/configs/variance.yaml b/configs/variance.yaml index 7ca5d08b..1e119cc0 100644 --- a/configs/variance.yaml +++ b/configs/variance.yaml @@ -108,5 +108,8 @@ permanent_ckpt_interval: 40000 finetune_enable: false finetune_ckpt_path: null -finetune_ignored_params: ['model.spk_embed','model.fs2.txt_embed.weight','model.fs2.encoder.embed_tokens.weight'] +finetune_ignored_params: + - model.spk_embed + - model.fs2.txt_embed + - model.fs2.encoder.embed_tokens finetune_strict_shapes: true From e7791065a9fb0c4916b52e323d60bff5cb0abc02 Mon Sep 17 00:00:00 2001 From: autumn <2> Date: Sat, 8 Jul 2023 15:21:23 +0800 Subject: [PATCH 22/23] support pre_train model add doc --- basics/base_task.py | 2 +- configs/acoustic.yaml | 2 +- configs/base.yaml | 2 +- configs/variance.yaml | 2 +- docs/ConfigurationSchemas.md | 162 ++++++++++++++++++----------------- 5 files changed, 86 insertions(+), 84 deletions(-) diff --git a/basics/base_task.py b/basics/base_task.py index 433abbff..a9f47159 100644 --- a/basics/base_task.py +++ b/basics/base_task.py @@ -88,7 +88,7 @@ def setup(self, stage): self.phone_encoder = self.build_phone_encoder() self.model = self.build_model() # utils.load_warp(self) - if hparams['finetune_enable'] and get_latest_checkpoint_path(pathlib.Path(hparams['work_dir'])) is None: + if hparams['finetune_enabled'] and get_latest_checkpoint_path(pathlib.Path(hparams['work_dir'])) is None: self.load_finetune_ckpt( self.load_pre_train_model()) self.print_arch() self.build_losses() diff --git a/configs/acoustic.yaml b/configs/acoustic.yaml index ed17dc40..42ead022 100644 --- a/configs/acoustic.yaml +++ b/configs/acoustic.yaml @@ -94,7 +94,7 @@ permanent_ckpt_start: 120000 permanent_ckpt_interval: 40000 -finetune_enable: false +finetune_enabled: false finetune_ckpt_path: null finetune_ignored_params: diff --git a/configs/base.yaml b/configs/base.yaml index 7b4f2990..bc570e46 100644 --- a/configs/base.yaml +++ b/configs/base.yaml @@ -94,7 +94,7 @@ ddp_backend: 'nccl' # choose from 'gloo', 'nccl', 'nccl_no_p2p' # finetune ########### -finetune_enable: false +finetune_enabled: false finetune_ckpt_path: null finetune_ignored_params: [] diff --git a/configs/variance.yaml b/configs/variance.yaml index 1e119cc0..1087fc7a 100644 --- a/configs/variance.yaml +++ b/configs/variance.yaml @@ -106,7 +106,7 @@ permanent_ckpt_start: 120000 permanent_ckpt_interval: 40000 -finetune_enable: false +finetune_enabled: false finetune_ckpt_path: null finetune_ignored_params: - model.spk_embed diff --git a/docs/ConfigurationSchemas.md b/docs/ConfigurationSchemas.md index 75c8ee0e..13175129 100644 --- a/docs/ConfigurationSchemas.md +++ b/docs/ConfigurationSchemas.md @@ -1304,6 +1304,88 @@ int 2048 +### finetune_enable + +use pretrain model + +#### visibility + +all + +#### type + +boolean + +#### default + +false + +#### constraints + +Must be true if use pretrain model + +### finetune_ckpt_path + +your pretrain path + +#### visibility + +all + +#### type + +str + +#### default + +null + +#### constraints + +Must be a path + +### finetune_ignored_params + +the params you want to ignore in finetune + +#### visibility + +all + +#### type + +list + +#### default + +null + +#### constraints + +Must be a list + +### finetune_strict_shapes + +when you finetune model have some shapes mismatch the model ignored or error +default is error + +#### visibility + +all + +#### type + +boolean + +#### default + + true + +#### constraints + +Must be a boolean + + ### fmax Maximum frequency of mel extraction. @@ -3323,83 +3405,3 @@ int 2048 -### finetune_enable - -use pretrain model - -#### visibility - -all - -#### type - -boolean - -#### default - -false - -#### constraints - -Must be true if use pretrain model - -### finetune_ckpt_path - -your pretrain path - -#### visibility - -all - -#### type - -str - -#### default - -null - -#### constraints - -Must be a path - -### finetune_ignored_params - -the params you want to ignore in finetune - -#### visibility - -all - -#### type - -list - -#### default - -null - -#### constraints - -Must be a list - -### finetune_strict_shapes - -when you finetune model have some shapes mismatch the model ignored or error -default is error - -#### visibility - -all - -#### type - -boolean - -#### default - - true - -#### constraints - -Must be a boolean \ No newline at end of file From ad9576e0676c8c10f079cc12bbcfc72fe8d3cc8b Mon Sep 17 00:00:00 2001 From: yqzhishen Date: Mon, 17 Jul 2023 20:23:18 +0800 Subject: [PATCH 23/23] Update docs for finetuning --- docs/ConfigurationSchemas.md | 66 +++++++++++++++++++++--------------- 1 file changed, 38 insertions(+), 28 deletions(-) diff --git a/docs/ConfigurationSchemas.md b/docs/ConfigurationSchemas.md index 3edd92fa..417350a4 100644 --- a/docs/ConfigurationSchemas.md +++ b/docs/ConfigurationSchemas.md @@ -1306,34 +1306,46 @@ int 2048 -### finetune_enable +### finetune_enabled -use pretrain model +Whether to finetune from a pretrained model. #### visibility all -#### type +#### scope -boolean +training -#### default +#### customizability -false +normal -#### constraints +#### type + +bool -Must be true if use pretrain model +#### default + +False ### finetune_ckpt_path -your pretrain path +Path to the pretrained model for finetuning. #### visibility all +#### scope + +training + +#### customizability + +normal + #### type str @@ -1342,51 +1354,49 @@ str null -#### constraints - -Must be a path - ### finetune_ignored_params -the params you want to ignore in finetune +Prefixes of parameter key names in the state dict of the pretrained model that need to be dropped before finetuning. #### visibility all -#### type +#### scope -list +training -#### default +#### customizability -null +normal -#### constraints +#### type -Must be a list +list ### finetune_strict_shapes -when you finetune model have some shapes mismatch the model ignored or error -default is error +Whether to raise error if the tensor shapes of any parameter of the pretrained model and the target model mismatch. If set to `False`, parameters with mismatching shapes will be skipped. #### visibility all -#### type +#### scope -boolean +training -#### default +#### customizability - true +normal -#### constraints +#### type -Must be a boolean +bool + +#### default +True ### fmax