Skip to content

Commit

Permalink
Load checkpoint from a specific path
Browse files Browse the repository at this point in the history
  • Loading branch information
yuekaizhang committed Mar 5, 2024
1 parent 73a7687 commit 50b575a
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 5 deletions.
3 changes: 0 additions & 3 deletions egs/multi_zh-hans/ASR/prepare.sh
Original file line number Diff line number Diff line change
Expand Up @@ -106,13 +106,10 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
log "Stage 5: Prepare AISHELL-4"
if [ -e ../../aishell4/ASR/data/fbank/.fbank.done ]; then
cd data/fbank
ln -svf $(realpath ../../../../aishell4/ASR/data/fbank/aishell4_feats_train) .
ln -svf $(realpath ../../../../aishell4/ASR/data/fbank/aishell4_feats_dev) .
ln -svf $(realpath ../../../../aishell4/ASR/data/fbank/aishell4_feats_test) .
ln -svf $(realpath ../../../../aishell4/ASR/data/fbank/aishell4_cuts_train_L.jsonl.gz) .
ln -svf $(realpath ../../../../aishell4/ASR/data/fbank/aishell4_cuts_train_M.jsonl.gz) .
ln -svf $(realpath ../../../../aishell4/ASR/data/fbank/aishell4_cuts_train_S.jsonl.gz) .
ln -svf $(realpath ../../../../aishell4/ASR/data/fbank/aishell4_cuts_dev.jsonl.gz) .
ln -svf $(realpath ../../../../aishell4/ASR/data/fbank/aishell4_cuts_test.jsonl.gz) .
cd ../..
else
Expand Down
2 changes: 1 addition & 1 deletion egs/multi_zh-hans/ASR/whisper/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,7 +496,7 @@ def remove_long_utt(c: Cut):

test_sets = test_sets_cuts.keys()
test_dls = [
data_module.test_dataloaders(test_sets_cuts[cuts_name].filter(remove_short_utt))
data_module.test_dataloaders(test_sets_cuts[cuts_name].filter(remove_long_utt))
for cuts_name in test_sets
]

Expand Down
23 changes: 22 additions & 1 deletion egs/multi_zh-hans/ASR/whisper/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
--model-name medium
"""


import os
import argparse
import copy
import logging
Expand Down Expand Up @@ -151,6 +151,15 @@ def get_parser():
""",
)

parser.add_argument(
"--pretrained-model-path",
type=str,
default=None,
help="""The path to the pretrained model if it is not None. Training will
start from this model. e.g. ./wenetspeech/ASR/whisper/exp_large_v2/epoch-4-avg-3.pt
""",
)

parser.add_argument(
"--base-lr", type=float, default=1e-5, help="The base learning rate."
)
Expand Down Expand Up @@ -617,6 +626,7 @@ def train_one_epoch(
f"{params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}.pt",
tag=f"epoch-{params.cur_epoch}-checkpoint-{batch_idx}",
)
os.system(f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}")

try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
Expand Down Expand Up @@ -749,6 +759,16 @@ def run(rank, world_size, args):
replace_whisper_encoder_forward()
model = whisper.load_model(params.model_name, "cpu")
del model.alignment_heads

if params.pretrained_model_path:
checkpoint = torch.load(
params.pretrained_model_path, map_location="cpu"
)
if "model" not in checkpoint:
model.load_state_dict(checkpoint, strict=True)
else:
load_checkpoint(params.pretrained_model_path, model)

num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")

Expand Down Expand Up @@ -900,6 +920,7 @@ def remove_short_and_long_utt(c: Cut):
f"{params.exp_dir}/epoch-{params.cur_epoch}.pt",
tag=f"epoch-{params.cur_epoch}",
)
os.system(f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}")
else:
save_checkpoint(
params=params,
Expand Down

0 comments on commit 50b575a

Please sign in to comment.