From 5f8a5f91c0259225312d18edd2b0c005062a4fd9 Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Tue, 20 Aug 2024 14:07:54 +0800 Subject: [PATCH 1/3] update load function Signed-off-by: Yiheng Wang --- monai/networks/nets/diffusion_model_unet.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/monai/networks/nets/diffusion_model_unet.py b/monai/networks/nets/diffusion_model_unet.py index f57fe251d2..bb882759aa 100644 --- a/monai/networks/nets/diffusion_model_unet.py +++ b/monai/networks/nets/diffusion_model_unet.py @@ -1837,11 +1837,18 @@ def load_old_state_dict(self, old_state_dict: dict, verbose=False) -> None: new_state_dict[k] = old_state_dict.pop(k) # fix the attention blocks - attention_blocks = [k.replace(".out_proj.weight", "") for k in new_state_dict if "out_proj.weight" in k] + attention_blocks = [k.replace(".attn.to_k.weight", "") for k in new_state_dict if "attn.to_k.weight" in k] for block in attention_blocks: + new_state_dict[f"{block}.attn.to_q.weight"] = old_state_dict.pop(f"{block}.to_q.weight") + new_state_dict[f"{block}.attn.to_k.weight"] = old_state_dict.pop(f"{block}.to_k.weight") + new_state_dict[f"{block}.attn.to_v.weight"] = old_state_dict.pop(f"{block}.to_v.weight") + new_state_dict[f"{block}.attn.to_q.bias"] = old_state_dict.pop(f"{block}.to_q.bias") + new_state_dict[f"{block}.attn.to_k.bias"] = old_state_dict.pop(f"{block}.to_k.bias") + new_state_dict[f"{block}.attn.to_v.bias"] = old_state_dict.pop(f"{block}.to_v.bias") + # projection - new_state_dict[f"{block}.out_proj.weight"] = old_state_dict.pop(f"{block}.to_out.0.weight") - new_state_dict[f"{block}.out_proj.bias"] = old_state_dict.pop(f"{block}.to_out.0.bias") + new_state_dict[f"{block}.attn.out_proj.weight"] = old_state_dict.pop(f"{block}.proj_attn.weight") + new_state_dict[f"{block}.attn.out_proj.bias"] = old_state_dict.pop(f"{block}.proj_attn.bias") # fix the upsample conv blocks which were renamed postconv for k in new_state_dict: From 4839feb7756fd3affcdd213ed9e197c014f322a7 Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Tue, 20 Aug 2024 15:31:16 +0800 Subject: [PATCH 2/3] add cross attention Signed-off-by: Yiheng Wang --- monai/networks/nets/diffusion_model_unet.py | 10 ++++++++++ tests/test_diffusion_model_unet.py | 2 +- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/monai/networks/nets/diffusion_model_unet.py b/monai/networks/nets/diffusion_model_unet.py index bb882759aa..65d6053acc 100644 --- a/monai/networks/nets/diffusion_model_unet.py +++ b/monai/networks/nets/diffusion_model_unet.py @@ -1850,6 +1850,16 @@ def load_old_state_dict(self, old_state_dict: dict, verbose=False) -> None: new_state_dict[f"{block}.attn.out_proj.weight"] = old_state_dict.pop(f"{block}.proj_attn.weight") new_state_dict[f"{block}.attn.out_proj.bias"] = old_state_dict.pop(f"{block}.proj_attn.bias") + # fix the cross attention blocks + cross_attention_blocks = [ + k.replace(".out_proj.weight", "") + for k in new_state_dict + if "out_proj.weight" in k and "transformer_blocks" in k + ] + for block in cross_attention_blocks: + new_state_dict[f"{block}.out_proj.weight"] = old_state_dict.pop(f"{block}.to_out.0.weight") + new_state_dict[f"{block}.out_proj.bias"] = old_state_dict.pop(f"{block}.to_out.0.bias") + # fix the upsample conv blocks which were renamed postconv for k in new_state_dict: if "postconv" in k: diff --git a/tests/test_diffusion_model_unet.py b/tests/test_diffusion_model_unet.py index 7f764d85de..db3df5bcb6 100644 --- a/tests/test_diffusion_model_unet.py +++ b/tests/test_diffusion_model_unet.py @@ -578,7 +578,7 @@ def test_compatibility_with_monai_generative(self): weight_path = os.path.join(tmpdir, filename) download_url(url=url, filepath=weight_path, hash_val=hash_val, hash_type=hash_type) - net.load_old_state_dict(torch.load(weight_path), verbose=False) + net.load_old_state_dict(torch.load(weight_path), verbose=True) if __name__ == "__main__": From 412af6ea80771817d37f5b07421d9572b32dd400 Mon Sep 17 00:00:00 2001 From: Yiheng Wang <68361391+yiheng-wang-nv@users.noreply.github.com> Date: Tue, 20 Aug 2024 15:40:23 +0800 Subject: [PATCH 3/3] Update tests/test_diffusion_model_unet.py Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Signed-off-by: Yiheng Wang <68361391+yiheng-wang-nv@users.noreply.github.com> --- tests/test_diffusion_model_unet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_diffusion_model_unet.py b/tests/test_diffusion_model_unet.py index db3df5bcb6..7f764d85de 100644 --- a/tests/test_diffusion_model_unet.py +++ b/tests/test_diffusion_model_unet.py @@ -578,7 +578,7 @@ def test_compatibility_with_monai_generative(self): weight_path = os.path.join(tmpdir, filename) download_url(url=url, filepath=weight_path, hash_val=hash_val, hash_type=hash_type) - net.load_old_state_dict(torch.load(weight_path), verbose=True) + net.load_old_state_dict(torch.load(weight_path), verbose=False) if __name__ == "__main__":