
Issues with loading model weights to reproduce the demo notebook #12

Open · bpiyush opened this issue Feb 16, 2023 · 2 comments

bpiyush commented Feb 16, 2023

Hi! Great work!

I was trying to reproduce the demo in this notebook. While loading model weights from a pre-trained checkpoint using:

import os
from os.path import join

# repo_path is the local path to the cloned repository
chkpt_dir = join(repo_path, "external/AudioMAE/checkpoints", "pretrained.pth")
assert os.path.exists(chkpt_dir), f"Checkpoint does not exist at {chkpt_dir}"

# build the model and load the pre-trained weights
model = prepare_model(chkpt_dir, 'mae_vit_base_patch16')
print('Model loaded.')

I get the following warning message:

_IncompatibleKeys(missing_keys=[], unexpected_keys=['decoder_blocks.8.attn.tau', 'decoder_blocks.8.attn.qkv.weight', 'decoder_blocks.8.attn.qkv.bias', 'decoder_blocks.8.attn.proj.weight', 'decoder_blocks.8.attn.proj.bias', 'decoder_blocks.8.attn.meta_mlp.fc1.weight', 'decoder_blocks.8.attn.meta_mlp.fc1.bias', 'decoder_blocks.8.attn.meta_mlp.fc2.weight', 'decoder_blocks.8.attn.meta_mlp.fc2.bias', 'decoder_blocks.8.norm1.weight', 'decoder_blocks.8.norm1.bias', 'decoder_blocks.8.mlp.fc1.weight', 'decoder_blocks.8.mlp.fc1.bias', 'decoder_blocks.8.mlp.fc2.weight', 'decoder_blocks.8.mlp.fc2.bias', 'decoder_blocks.8.norm2.weight', 'decoder_blocks.8.norm2.bias', 'decoder_blocks.9.attn.tau', 'decoder_blocks.9.attn.qkv.weight', 'decoder_blocks.9.attn.qkv.bias', 'decoder_blocks.9.attn.proj.weight', 'decoder_blocks.9.attn.proj.bias'.....

I believe it isn't loading the decoder weights correctly. Could you please help me fix this? @berniebear
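
To see exactly which decoder keys are mismatched, here is a minimal diagnostic sketch (assuming the chkpt_dir and model from the snippet above):

import torch

# inspect the checkpoint directly; 'model' is the key used in prepare_model
checkpoint = torch.load(chkpt_dir, map_location='cpu')
ckpt_keys = set(checkpoint['model'].keys())
model_keys = set(model.state_dict().keys())

# decoder weights present in the checkpoint but absent from the built model
print(sorted(k for k in ckpt_keys - model_keys if k.startswith('decoder_blocks')))
# decoder weights expected by the model but missing from the checkpoint
print(sorted(k for k in model_keys - ckpt_keys if k.startswith('decoder_blocks')))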

Thanks!

@asheff794

Hi @bpiyush, I am getting what appears to be the same error running the demo.py notebook. Did you ever find a solution? The output I'm getting looks like noise, so I think you're correct that the decoder weights aren't loading properly.

_IncompatibleKeys(missing_keys=[], unexpected_keys=['decoder_blocks.8.attn.tau', 'decoder_blocks.8.attn.qkv.weight', 'decoder_blocks.8.attn.qkv.bias', 'decoder_blocks.8.attn.proj.weight', 'decoder_blocks.8.attn.proj.bias', 'decoder_blocks.8.attn.meta_mlp.fc1.weight', 'decoder_blocks.8.attn.meta_mlp.fc1.bias', 'decoder_blocks.8.attn.meta_mlp.fc2.weight', 'decoder_blocks.8.attn.meta_mlp.fc2.bias', 'decoder_blocks.8.norm1.weight', 'decoder_blocks.8.norm1.bias', 'decoder_blocks.8.mlp.fc1.weight', 'decoder_blocks.8.mlp.fc1.bias', 'decoder_blocks.8.mlp.fc2.weight', 'decoder_blocks.8.mlp.fc2.bias', 'decoder_blocks.8.norm2.weight', 'decoder_blocks.8.norm2.bias', 'decoder_blocks.9.attn.tau', 'decoder_blocks.9.attn.qkv.weight', 'decoder_blocks.9.attn.qkv.bias', 'decoder_blocks.9.attn.proj.weight', 'decoder_blocks.9.attn.proj.bias', 'decoder_blocks.9.attn.meta_mlp.fc1.weight', 'decoder_blocks.9.attn.meta_mlp.fc1.bias', 'decoder_blocks.9.attn.meta_mlp.fc2.weight', 'decoder_blocks.9.attn.meta_mlp.fc2.bias', 'decoder_blocks.9.norm1.weight', 'decoder_blocks.9.norm1.bias', 'decoder_blocks.9.mlp.fc1.weight', 'decoder_blocks.9.mlp.fc1.bias', 'decoder_blocks.9.mlp.fc2.weight', 'decoder_blocks.9.mlp.fc2.bias', 'decoder_blocks.9.norm2.weight', 'decoder_blocks.9.norm2.bias', 'decoder_blocks.10.attn.tau', 'decoder_blocks.10.attn.qkv.weight', 'decoder_blocks.10.attn.qkv.bias', 'decoder_blocks.10.attn.proj.weight', 'decoder_blocks.10.attn.proj.bias', 'decoder_blocks.10.attn.meta_mlp.fc1.weight', 'decoder_blocks.10.attn.meta_mlp.fc1.bias', 'decoder_blocks.10.attn.meta_mlp.fc2.weight', 'decoder_blocks.10.attn.meta_mlp.fc2.bias', 'decoder_blocks.10.norm1.weight', 'decoder_blocks.10.norm1.bias', 'decoder_blocks.10.mlp.fc1.weight', 'decoder_blocks.10.mlp.fc1.bias', 'decoder_blocks.10.mlp.fc2.weight', 'decoder_blocks.10.mlp.fc2.bias', 'decoder_blocks.10.norm2.weight', 'decoder_blocks.10.norm2.bias', 'decoder_blocks.11.attn.tau', 'decoder_blocks.11.attn.qkv.weight', 'decoder_blocks.11.attn.qkv.bias', 'decoder_blocks.11.attn.proj.weight', 'decoder_blocks.11.attn.proj.bias', 'decoder_blocks.11.attn.meta_mlp.fc1.weight', 'decoder_blocks.11.attn.meta_mlp.fc1.bias', 'decoder_blocks.11.attn.meta_mlp.fc2.weight', 'decoder_blocks.11.attn.meta_mlp.fc2.bias', 'decoder_blocks.11.norm1.weight', 'decoder_blocks.11.norm1.bias', 'decoder_blocks.11.mlp.fc1.weight', 'decoder_blocks.11.mlp.fc1.bias', 'decoder_blocks.11.mlp.fc2.weight', 'decoder_blocks.11.mlp.fc2.bias', 'decoder_blocks.11.norm2.weight', 'decoder_blocks.11.norm2.bias', 'decoder_blocks.12.attn.tau', 'decoder_blocks.12.attn.qkv.weight', 'decoder_blocks.12.attn.qkv.bias', 'decoder_blocks.12.attn.proj.weight', 'decoder_blocks.12.attn.proj.bias', 'decoder_blocks.12.attn.meta_mlp.fc1.weight', 'decoder_blocks.12.attn.meta_mlp.fc1.bias', 'decoder_blocks.12.attn.meta_mlp.fc2.weight', 'decoder_blocks.12.attn.meta_mlp.fc2.bias', 'decoder_blocks.12.norm1.weight', 'decoder_blocks.12.norm1.bias', 'decoder_blocks.12.mlp.fc1.weight', 'decoder_blocks.12.mlp.fc1.bias', 'decoder_blocks.12.mlp.fc2.weight', 'decoder_blocks.12.mlp.fc2.bias', 'decoder_blocks.12.norm2.weight', 'decoder_blocks.12.norm2.bias', 'decoder_blocks.13.attn.tau', 'decoder_blocks.13.attn.qkv.weight', 'decoder_blocks.13.attn.qkv.bias', 'decoder_blocks.13.attn.proj.weight', 'decoder_blocks.13.attn.proj.bias', 'decoder_blocks.13.attn.meta_mlp.fc1.weight', 'decoder_blocks.13.attn.meta_mlp.fc1.bias', 'decoder_blocks.13.attn.meta_mlp.fc2.weight', 'decoder_blocks.13.attn.meta_mlp.fc2.bias', 'decoder_blocks.13.norm1.weight', 
'decoder_blocks.13.norm1.bias', 'decoder_blocks.13.mlp.fc1.weight', 'decoder_blocks.13.mlp.fc1.bias', 'decoder_blocks.13.mlp.fc2.weight', 'decoder_blocks.13.mlp.fc2.bias', 'decoder_blocks.13.norm2.weight', 'decoder_blocks.13.norm2.bias', 'decoder_blocks.14.attn.tau', 'decoder_blocks.14.attn.qkv.weight', 'decoder_blocks.14.attn.qkv.bias', 'decoder_blocks.14.attn.proj.weight', 'decoder_blocks.14.attn.proj.bias', 'decoder_blocks.14.attn.meta_mlp.fc1.weight', 'decoder_blocks.14.attn.meta_mlp.fc1.bias', 'decoder_blocks.14.attn.meta_mlp.fc2.weight', 'decoder_blocks.14.attn.meta_mlp.fc2.bias', 'decoder_blocks.14.norm1.weight', 'decoder_blocks.14.norm1.bias', 'decoder_blocks.14.mlp.fc1.weight', 'decoder_blocks.14.mlp.fc1.bias', 'decoder_blocks.14.mlp.fc2.weight', 'decoder_blocks.14.mlp.fc2.bias', 'decoder_blocks.14.norm2.weight', 'decoder_blocks.14.norm2.bias', 'decoder_blocks.15.attn.tau', 'decoder_blocks.15.attn.qkv.weight', 'decoder_blocks.15.attn.qkv.bias', 'decoder_blocks.15.attn.proj.weight', 'decoder_blocks.15.attn.proj.bias', 'decoder_blocks.15.attn.meta_mlp.fc1.weight', 'decoder_blocks.15.attn.meta_mlp.fc1.bias', 'decoder_blocks.15.attn.meta_mlp.fc2.weight', 'decoder_blocks.15.attn.meta_mlp.fc2.bias', 'decoder_blocks.15.norm1.weight', 'decoder_blocks.15.norm1.bias', 'decoder_blocks.15.mlp.fc1.weight', 'decoder_blocks.15.mlp.fc1.bias', 'decoder_blocks.15.mlp.fc2.weight', 'decoder_blocks.15.mlp.fc2.bias', 'decoder_blocks.15.norm2.weight', 'decoder_blocks.15.norm2.bias', 'decoder_blocks.0.attn.tau', 'decoder_blocks.0.attn.meta_mlp.fc1.weight', 'decoder_blocks.0.attn.meta_mlp.fc1.bias', 'decoder_blocks.0.attn.meta_mlp.fc2.weight', 'decoder_blocks.0.attn.meta_mlp.fc2.bias', 'decoder_blocks.1.attn.tau', 'decoder_blocks.1.attn.meta_mlp.fc1.weight', 'decoder_blocks.1.attn.meta_mlp.fc1.bias', 'decoder_blocks.1.attn.meta_mlp.fc2.weight', 'decoder_blocks.1.attn.meta_mlp.fc2.bias', 'decoder_blocks.2.attn.tau', 'decoder_blocks.2.attn.meta_mlp.fc1.weight', 'decoder_blocks.2.attn.meta_mlp.fc1.bias', 'decoder_blocks.2.attn.meta_mlp.fc2.weight', 'decoder_blocks.2.attn.meta_mlp.fc2.bias', 'decoder_blocks.3.attn.tau', 'decoder_blocks.3.attn.meta_mlp.fc1.weight', 'decoder_blocks.3.attn.meta_mlp.fc1.bias', 'decoder_blocks.3.attn.meta_mlp.fc2.weight', 'decoder_blocks.3.attn.meta_mlp.fc2.bias', 'decoder_blocks.4.attn.tau', 'decoder_blocks.4.attn.meta_mlp.fc1.weight', 'decoder_blocks.4.attn.meta_mlp.fc1.bias', 'decoder_blocks.4.attn.meta_mlp.fc2.weight', 'decoder_blocks.4.attn.meta_mlp.fc2.bias', 'decoder_blocks.5.attn.tau', 'decoder_blocks.5.attn.meta_mlp.fc1.weight', 'decoder_blocks.5.attn.meta_mlp.fc1.bias', 'decoder_blocks.5.attn.meta_mlp.fc2.weight', 'decoder_blocks.5.attn.meta_mlp.fc2.bias', 'decoder_blocks.6.attn.tau', 'decoder_blocks.6.attn.meta_mlp.fc1.weight', 'decoder_blocks.6.attn.meta_mlp.fc1.bias', 'decoder_blocks.6.attn.meta_mlp.fc2.weight', 'decoder_blocks.6.attn.meta_mlp.fc2.bias', 'decoder_blocks.7.attn.tau', 'decoder_blocks.7.attn.meta_mlp.fc1.weight', 'decoder_blocks.7.attn.meta_mlp.fc1.bias', 'decoder_blocks.7.attn.meta_mlp.fc2.weight', 'decoder_blocks.7.attn.meta_mlp.fc2.bias']) Model loaded.

Any thoughts, @berniebear?

@asheff794

I believe I found the issue. I was able to run the dev2 demo and noticed that notebook passes a decoder_mode flag set to 1 in the prepare_model function. I made that single modification to the demo notebook and can now load the model without error and get reasonable reconstructions in the output.

import torch
import models_mae  # from the AudioMAE repo

def prepare_model(chkpt_dir, arch='mae_vit_base_patch16'):
    # build the model; decoder_mode=1 matches the decoder used for the pre-trained checkpoint
    model = getattr(models_mae, arch)(in_chans=1, audio_exp=True, img_size=(1024, 128), decoder_mode=1)
    # load the pre-trained weights
    checkpoint = torch.load(chkpt_dir, map_location='cuda')
    msg = model.load_state_dict(checkpoint['model'], strict=False)
    print(msg)
    return model
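
With that change, re-running the loading cell from the demo (a minimal sketch, assuming the same chkpt_dir as above) no longer reports the decoder_blocks.* keys as unexpected:

model = prepare_model(chkpt_dir, 'mae_vit_base_patch16')
# the printed _IncompatibleKeys message should no longer list decoder_blocks.* under unexpected_keys
print('Model loaded.')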
