support finetune #108

Merged (24 commits) on Jul 17, 2023
78 changes: 72 additions & 6 deletions basics/base_task.py
@@ -87,11 +87,70 @@ def __init__(self, *args, **kwargs):
def setup(self, stage):
    self.phone_encoder = self.build_phone_encoder()
    self.model = self.build_model()
    # Load pretrained weights for finetuning only when starting from scratch,
    # i.e. when no checkpoint exists yet in the work directory.
    if hparams['finetune_enabled'] and get_latest_checkpoint_path(pathlib.Path(hparams['work_dir'])) is None:
        self.load_finetune_ckpt(self.load_pre_train_model())
    self.print_arch()
    self.build_losses()
    self.train_dataset = self.dataset_cls(hparams['train_set_name'])
    self.valid_dataset = self.dataset_cls(hparams['valid_set_name'])

def load_finetune_ckpt(self, state_dict):
    adapt_shapes = hparams['finetune_strict_shapes']
    if not adapt_shapes:
        # Drop pretrained tensors whose shapes do not match the current model.
        cur_model_state_dict = self.state_dict()
        unmatched_keys = []
        for key, param in state_dict.items():
            if key in cur_model_state_dict:
                new_param = cur_model_state_dict[key]
                if new_param.shape != param.shape:
                    unmatched_keys.append(key)
                    print('| Unmatched keys: ', key, new_param.shape, param.shape)
        for key in unmatched_keys:
            del state_dict[key]
    self.load_state_dict(state_dict, strict=False)

def load_pre_train_model(self):
    pre_train_ckpt_path = hparams.get('finetune_ckpt_path')
    blacklist = hparams.get('finetune_ignored_params')
    if blacklist is None:
        blacklist = []

    if pre_train_ckpt_path is not None:
        ckpt = torch.load(pre_train_ckpt_path)
        if isinstance(self.model, CategorizedModule):
            self.model.check_category(ckpt.get('category'))

        state_dict = {}
        for i in ckpt['state_dict']:
            # Skip any parameter whose key starts with a blacklisted prefix.
            skip = False
            for b in blacklist:
                if i.startswith(b):
                    skip = True
                    break
            if skip:
                continue
            state_dict[i] = ckpt['state_dict'][i]
            print(i)  # log the keys that will be loaded
        return state_dict
    else:
        raise RuntimeError('finetune_ckpt_path is not specified.')

@staticmethod
def build_phone_encoder():
    phone_list = build_phoneme_list()

@@ -291,6 +350,11 @@ def on_test_end(self):
def start(cls):
    pl.seed_everything(hparams['seed'], workers=True)
    task = cls()
    work_dir = pathlib.Path(hparams['work_dir'])
    trainer = pl.Trainer(
        accelerator=hparams['pl_trainer_accelerator'],

@@ -378,16 +442,16 @@ def on_load_checkpoint(self, checkpoint):
    from utils import simulate_lr_scheduler
    if checkpoint.get('trainer_stage', '') == RunningStage.VALIDATING.value:
        self.skip_immediate_validation = True

    optimizer_args = hparams['optimizer_args']
    scheduler_args = hparams['lr_scheduler_args']

    if 'beta1' in optimizer_args and 'beta2' in optimizer_args and 'betas' not in optimizer_args:
        optimizer_args['betas'] = (optimizer_args['beta1'], optimizer_args['beta2'])

    if checkpoint.get('optimizer_states', None):
        opt_states = checkpoint['optimizer_states']
        assert len(opt_states) == 1  # only support one optimizer
        opt_state = opt_states[0]
        for param_group in opt_state['param_groups']:
            for k, v in optimizer_args.items():

@@ -397,13 +461,14 @@

                    rank_zero_info(f'| Overriding optimizer parameter {k} from checkpoint: {param_group[k]} -> {v}')
                    param_group[k] = v
            if 'initial_lr' in param_group and param_group['initial_lr'] != optimizer_args['lr']:
                rank_zero_info(
                    f'| Overriding optimizer parameter initial_lr from checkpoint: {param_group["initial_lr"]} -> {optimizer_args["lr"]}')
                param_group['initial_lr'] = optimizer_args['lr']

    if checkpoint.get('lr_schedulers', None):
        assert checkpoint.get('optimizer_states', False)
        schedulers = checkpoint['lr_schedulers']
        assert len(schedulers) == 1  # only support one scheduler
        scheduler = schedulers[0]
        for k, v in scheduler_args.items():
            if k in scheduler and scheduler[k] != v:

@@ -418,5 +483,6 @@

        scheduler['_last_lr'] = new_lrs
        for param_group, new_lr in zip(checkpoint['optimizer_states'][0]['param_groups'], new_lrs):
            if param_group['lr'] != new_lr:
                rank_zero_info(
                    f'| Overriding optimizer parameter lr from checkpoint: {param_group["lr"]} -> {new_lr}')
                param_group['lr'] = new_lr
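
The finetuning hook above boils down to two steps: `load_pre_train_model` drops every key whose name starts with a prefix listed in `finetune_ignored_params`, and `load_finetune_ckpt` optionally discards tensors whose shapes disagree with the current model before a non-strict `load_state_dict`. The sketch below reproduces that filtering logic in isolation on a toy module so the behavior can be checked without the full task class; the helper names and the toy model are illustrative, not taken from the repository.

```python
import torch.nn as nn


def filter_pretrained_state(state_dict, blacklist_prefixes):
    # Mirror the finetune_ignored_params check: drop every parameter whose
    # key starts with a blacklisted prefix.
    return {
        k: v for k, v in state_dict.items()
        if not any(k.startswith(p) for p in blacklist_prefixes)
    }


def drop_mismatched_shapes(state_dict, target_module):
    # Mirror the finetune_strict_shapes=False branch: skip tensors whose
    # shapes differ from the target model instead of letting loading fail.
    target_state = target_module.state_dict()
    kept = {}
    for k, v in state_dict.items():
        if k in target_state and target_state[k].shape != v.shape:
            print('| Unmatched key:', k, tuple(target_state[k].shape), tuple(v.shape))
            continue
        kept[k] = v
    return kept


# Toy "pretrained" and "target" models; the embedding sizes intentionally differ.
pretrained = nn.ModuleDict({'embed': nn.Embedding(10, 8), 'proj': nn.Linear(8, 4)})
target = nn.ModuleDict({'embed': nn.Embedding(12, 8), 'proj': nn.Linear(8, 4)})

state = filter_pretrained_state(pretrained.state_dict(), blacklist_prefixes=[])
state = drop_mismatched_shapes(state, target)  # drops 'embed.weight'
missing, unexpected = target.load_state_dict(state, strict=False)
print('loaded:', sorted(state), '| missing:', list(missing))
```

Running this reports the mismatched `embed.weight` and then loads only the projection parameters, which matches what the task sees when `finetune_strict_shapes` is set to `false`.
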
10 changes: 10 additions & 0 deletions configs/acoustic.yaml
@@ -92,3 +92,13 @@ max_updates: 320000
num_ckpt_keep: 5
permanent_ckpt_start: 200000
permanent_ckpt_interval: 40000


finetune_enabled: false
finetune_ckpt_path: null

finetune_ignored_params:
- model.fs2.encoder.embed_tokens
- model.fs2.txt_embed
- model.fs2.spk_embed
finetune_strict_shapes: true
11 changes: 11 additions & 0 deletions configs/base.yaml
@@ -89,3 +89,14 @@ pl_trainer_precision: '32-true'
pl_trainer_num_nodes: 1
pl_trainer_strategy: 'auto'
ddp_backend: 'nccl' # choose from 'gloo', 'nccl', 'nccl_no_p2p'

###########
# finetune
###########

finetune_enabled: false
finetune_ckpt_path: null
finetune_ignored_params: []


finetune_strict_shapes: true
8 changes: 8 additions & 0 deletions configs/variance.yaml
@@ -103,3 +103,11 @@ max_updates: 288000
num_ckpt_keep: 5
permanent_ckpt_start: 180000
permanent_ckpt_interval: 10000

finetune_enabled: false
finetune_ckpt_path: null
finetune_ignored_params:
- model.spk_embed
- model.fs2.txt_embed
- model.fs2.encoder.embed_tokens
finetune_strict_shapes: true
93 changes: 93 additions & 0 deletions docs/ConfigurationSchemas.md
@@ -1306,6 +1306,98 @@ int

2048

### finetune_enabled

Whether to finetune from a pretrained model.

#### visibility

all

#### scope

training

#### customizability

normal

#### type

bool

#### default

False

### finetune_ckpt_path

Path to the pretrained model for finetuning.

#### visibility

all

#### scope

training

#### customizability

normal

#### type

str

#### default

null

### finetune_ignored_params

Prefixes of parameter key names in the state dict of the pretrained model that need to be dropped before finetuning.

#### visibility

all

#### scope

training

#### customizability

normal

#### type

list

### finetune_strict_shapes

Whether to raise an error if the tensor shape of any parameter in the pretrained model does not match the corresponding shape in the target model. If set to `False`, parameters with mismatched shapes are skipped instead of loaded.

#### visibility

all

#### scope

training

#### customizability

normal

#### type

bool

#### default

True
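
Put together, a typical finetuning setup only needs these four keys in the experiment config. The snippet below is an illustrative example rather than a configuration shipped with the repository; the checkpoint path is a placeholder, and the ignored prefixes are the dataset-specific embeddings used as defaults in configs/acoustic.yaml.

```yaml
# Illustrative finetuning overrides; adapt the path to your pretrained model.
finetune_enabled: true
finetune_ckpt_path: checkpoints/my_pretrained_exp/model_ckpt_steps_320000.ckpt
finetune_ignored_params:           # dropped before loading (prefix match)
  - model.fs2.encoder.embed_tokens
  - model.fs2.txt_embed
  - model.fs2.spk_embed
finetune_strict_shapes: true       # raise on shape mismatch instead of skipping
```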

### fmax

Maximum frequency of mel extraction.
@@ -3324,3 +3416,4 @@ int

2048

2 changes: 2 additions & 0 deletions scripts/train.py
@@ -1,5 +1,6 @@
import importlib
import os

import sys
from pathlib import Path

@@ -22,6 +23,7 @@ def run_task():
    pkg = ".".join(hparams["task_cls"].split(".")[:-1])
    cls_name = hparams["task_cls"].split(".")[-1]
    task_cls = getattr(importlib.import_module(pkg), cls_name)

    task_cls.start()


16 changes: 14 additions & 2 deletions utils/__init__.py
@@ -11,6 +11,8 @@
import torch.nn.functional as F

from basics.base_module import CategorizedModule
from utils.hparams import hparams
from utils.training_utils import get_latest_checkpoint_path


def tensors_to_scalars(metrics):
@@ -149,7 +151,8 @@ def filter_kwargs(dict_to_filter, kwarg_obj):

    sig = inspect.signature(kwarg_obj)
    filter_keys = [param.name for param in sig.parameters.values() if param.kind == param.POSITIONAL_OR_KEYWORD]
    filtered_dict = {filter_key: dict_to_filter[filter_key] for filter_key in filter_keys if
                     filter_key in dict_to_filter}
    return filtered_dict


@@ -208,6 +211,14 @@ def load_ckpt(
    print(f'| load {shown_model_name} from \'{checkpoint_path}\'.')


def remove_padding(x, padding_idx=0):
    if x is None:
        return None
@@ -265,7 +276,8 @@ def simulate_lr_scheduler(optimizer_args, scheduler_args, last_epoch=-1, num_par
        [{'params': torch.nn.Parameter(), 'initial_lr': optimizer_args['lr']} for _ in range(num_param_groups)],
        **optimizer_args
    )
    scheduler = build_object_from_config(scheduler_args['scheduler_cls'], optimizer, last_epoch=last_epoch,
                                         **scheduler_args)

    if hasattr(scheduler, '_get_closed_form_lr'):
        return scheduler._get_closed_form_lr()