Skip to content

Commit

Permalink
Fix model avg (#1317)
Browse files Browse the repository at this point in the history
* Fix a bug in `model_avg` during fine-tuning by swapping the order of loading the pre-trained model and initializing the averaged model

* Only match keys with the exact module prefix (require a trailing "." so e.g. "encoder" does not also match "encoder_proj")
  • Loading branch information
marcoyang1998 authored Oct 18, 2023
1 parent 807816f commit 52c24df
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 4 deletions.
11 changes: 9 additions & 2 deletions egs/librispeech/ASR/pruned_transducer_stateless7/finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -655,8 +655,12 @@ def load_model_params(
dst_state_dict = model.state_dict()
for module in init_modules:
logging.info(f"Loading parameters starting with prefix {module}")
src_keys = [k for k in src_state_dict.keys() if k.startswith(module)]
dst_keys = [k for k in dst_state_dict.keys() if k.startswith(module)]
src_keys = [
k for k in src_state_dict.keys() if k.startswith(module.strip() + ".")
]
dst_keys = [
k for k in dst_state_dict.keys() if k.startswith(module.strip() + ".")
]
assert set(src_keys) == set(dst_keys) # two sets should match exactly
for key in src_keys:
dst_state_dict[key] = src_state_dict.pop(key)
Expand Down Expand Up @@ -1089,6 +1093,9 @@ def run(rank, world_size, args):
checkpoints = load_model_params(
ckpt=params.finetune_ckpt, model=model, init_modules=modules
)
if rank == 0:
# model_avg is only used with rank 0
model_avg = copy.deepcopy(model).to(torch.float64)
else:
assert params.start_epoch > 0, params.start_epoch
checkpoints = load_checkpoint_if_available(
Expand Down
8 changes: 6 additions & 2 deletions egs/wenetspeech/ASR/pruned_transducer_stateless2/finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,8 +498,12 @@ def load_model_params(
dst_state_dict = model.state_dict()
for module in init_modules:
logging.info(f"Loading parameters starting with prefix {module}")
src_keys = [k for k in src_state_dict.keys() if k.startswith(module)]
dst_keys = [k for k in dst_state_dict.keys() if k.startswith(module)]
src_keys = [
k for k in src_state_dict.keys() if k.startswith(module.strip() + ".")
]
dst_keys = [
k for k in dst_state_dict.keys() if k.startswith(module.strip() + ".")
]
assert set(src_keys) == set(dst_keys) # two sets should match exactly
for key in src_keys:
dst_state_dict[key] = src_state_dict.pop(key)
Expand Down

0 comments on commit 52c24df

Please sign in to comment.