
After fine-tuning on a 3-class dataset, how do I load the fine-tuned weights into the pre-trained AST model? #118

Open
jmren168 opened this issue Jan 18, 2024 · 7 comments
Labels
bug Something isn't working

Comments

@jmren168

Hi @YuanGongND,

In egs/audioset/inference.py, it shows that a pre-trained model can be loaded as follows:

    # 2. load the best model and the weights
    checkpoint_path = args.model_path
    ast_mdl = ASTModel(label_dim=527, input_tdim=input_tdim, imagenet_pretrain=False, audioset_pretrain=False)
    print(f'[*INFO] load checkpoint: {checkpoint_path}')
    checkpoint = torch.load(checkpoint_path, map_location='cuda')
    audio_model = torch.nn.DataParallel(ast_mdl, device_ids=[0])
    audio_model.load_state_dict(checkpoint)

However, when I fine-tuned the pre-trained model on a 3-class dataset and then reloaded it, I got this error:

size mismatch for module.v.pos_embed: copying a param with shape torch.Size([1, 602, 768]) from checkpoint, the shape in current model is torch.Size([1, 362, 768]).
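The two sizes in the error can be reproduced from the model configuration: the length of the positional embedding depends on `input_tdim`. A minimal sketch of the arithmetic (assuming 16×16 patches and the default `fstride`/`tstride` of 10, plus the [CLS] and distillation tokens):

```python
def num_pos_tokens(input_fdim, input_tdim, fstride=10, tstride=10):
    # Patch-split convolution uses a 16x16 kernel with the given strides.
    f_dim = (input_fdim - 16) // fstride + 1
    t_dim = (input_tdim - 16) // tstride + 1
    return f_dim * t_dim + 2  # +2 for the [CLS] and distillation tokens

print(num_pos_tokens(128, 512))  # 602 -> checkpoint saved with input_tdim=512
print(num_pos_tokens(128, 312))  # 362 -> model rebuilt with input_tdim=312
```

So the checkpoint was produced by a model configured with `input_tdim=512`, while the reload script builds one with `input_tdim=312`.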

My Python script:

    input_tdim=312
    class_num=3
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model_path = '/home/...'
    
    sd = torch.load(model_path, map_location=device)
    model = ASTModel(label_dim=class_num, fstride=10, tstride=10, input_fdim=128, input_tdim=input_tdim, audioset_pretrain=True, model_size='base384',verbose=False)
    model = torch.nn.DataParallel(model)
    model.load_state_dict(sd)
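As an aside, because the model is wrapped in `torch.nn.DataParallel` before saving and loading, every key in the checkpoint carries a `module.` prefix (hence `module.v.pos_embed` in the error). A toy sketch using a stand-in `torch.nn.Linear` rather than ASTModel:

```python
import torch

# Toy stand-in for ASTModel; the point is the key prefix, not the architecture.
net = torch.nn.Linear(8, 2)
wrapped = torch.nn.DataParallel(net)

print(list(net.state_dict().keys()))      # ['weight', 'bias']
print(list(wrapped.state_dict().keys()))  # ['module.weight', 'module.bias']
```

This is why the checkpoint must be loaded into a DataParallel-wrapped model (or the prefix stripped first).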

Am I missing something here? Any suggestions are appreciated.

@YuanGongND YuanGongND added the bug Something isn't working label Jan 18, 2024
@YuanGongND
Owner

It seems this is due to an inconsistent input_tdim between training and inference. Could you share the training script (in particular, what is input_tdim)? Thanks!

@jmren168
Author

Thanks for the reply. Here's the training script:

set=full
imagenetpretrain=True
if [ $set == balanced ]
then
  bal=none
  lr=5e-5
  epoch=25
  #tr_data=/data/sls/scratch/yuangong/aed-pc/src/enhance_label/datafiles_local/balanced_train_data_type1_2_mean.json
  tr_data=./data/datafiles/train_data.json
  lrscheduler_start=10
  lrscheduler_step=5
  lrscheduler_decay=0.5
  wa_start=6
  wa_end=25
else
  bal=bal
  lr=1e-5
  epoch=15 #5
  tr_data=./data/datafiles/train_data.json
  lrscheduler_start=4 #2
  lrscheduler_step=1 #1
  lrscheduler_decay=0.25 #0.5
  wa_start=1
  wa_end=15 #5
fi
#te_data=/data/sls/scratch/yuangong/audioset/datafiles/eval_data.json
te_data=./data/datafiles/valid_data.json
freqm=48
timem=62 # 192
mixup=0
# corresponding to overlap of 6 for 16*16 patches
fstride=10
tstride=10
batch_size=4 # 12

dataset_mean=-4.2677393
dataset_std=4.5689974
audio_length=512 #1024
noise=False

@YuanGongND
Owner

If you set audio_length=512 in training, then in inference, shouldn't the input_tdim=312 be 512?

@jmren168
Author

It works, and thanks again.

BTW, when I load the fine-tuned weights into the AudioSet-pretrained model, should I set audioset_pretrain=True or audioset_pretrain=False?

model = ASTModel(label_dim=class_num, fstride=10, tstride=10, input_fdim=128, input_tdim=input_tdim, audioset_pretrain=True, model_size='base384', verbose=False)

@YuanGongND
Owner

YuanGongND commented Jan 19, 2024

I guess it doesn't matter.

You can check with model.load_state_dict(sd, strict=True), which ensures the checkpoint fully covers all parameters (so which initial model you start from does not matter).

-Yuan
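A toy sketch of what strict=True guards against (a two-layer `Sequential` stand-in, not the real ASTModel, loaded from a deliberately incomplete state dict):

```python
import torch

# Stand-in model with two parameter groups ("0.*" and "1.*").
net = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Linear(4, 3))

# Keep only the first layer's parameters, simulating a partial checkpoint.
partial_sd = {k: v for k, v in net.state_dict().items() if k.startswith("0.")}

# strict=False silently leaves the "1.*" parameters at their initial values.
net.load_state_dict(partial_sd, strict=False)

# strict=True raises, reporting the missing "1.*" keys.
try:
    net.load_state_dict(partial_sd, strict=True)
except RuntimeError:
    print("missing keys detected")
```

If strict=True raises no error, every parameter in the model was covered by the checkpoint.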

@jmren168
Author

Just setting strict=True forces the new weights to be loaded. Thanks for the reply.

@YuanGongND
Owner

thanks for letting me know.

  • Just to clarify, strict=True itself does not change the loading behavior, but it will throw an error if the model parameters and the checkpoint are mismatched. If you don't see an error, that means all parameters were loaded from the checkpoint.

-Yuan
