Skip to content

Commit

Permalink
Merge pull request #33 from uncbiag/weights_only_patch
Browse files Browse the repository at this point in the history
Fix weights only warning message
  • Loading branch information
HastingsGreer authored Dec 11, 2024
2 parents 19cd155 + 15db795 commit f2781b5
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/unigradicon/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ def get_multigradicon(loss_fn=icon.LNCC(sigma=5)):
os.makedirs("network_weights/multigradicon1.0/", exist_ok=True)
urllib.request.urlretrieve(download_path, weights_location)
print(f"Loading weights from {weights_location}")
trained_weights = torch.load(weights_location, map_location=torch.device("cpu"))
trained_weights = torch.load(weights_location, map_location=torch.device("cpu"), weights_only=True)
net.regis_net.load_state_dict(trained_weights)
net.to(config.device)
net.eval()
Expand All @@ -223,7 +223,7 @@ def get_unigradicon(loss_fn=icon.LNCC(sigma=5)):
download_path = "https://github.com/uncbiag/uniGradICON/releases/download/unigradicon_weights/Step_2_final.trch"
os.makedirs("network_weights/unigradicon1.0/", exist_ok=True)
urllib.request.urlretrieve(download_path, weights_location)
trained_weights = torch.load(weights_location, map_location=torch.device("cpu"))
trained_weights = torch.load(weights_location, map_location=torch.device("cpu"), weights_only=True)
net.regis_net.load_state_dict(trained_weights)
net.to(config.device)
net.eval()
Expand Down

0 comments on commit f2781b5

Please sign in to comment.