Skip to content

Commit

Permalink
Merge pull request ORNL#58 from pzhanggit/model_load
Browse files Browse the repository at this point in the history
update load_existing_model
  • Loading branch information
pzhanggit authored Dec 7, 2021
2 parents 7fa8308 + c8aabb5 commit dffb15c
Showing 1 changed file with 9 additions and 0 deletions.
9 changes: 9 additions & 0 deletions hydragnn/utils/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from torch_geometric.utils import degree

from hydragnn.utils.distributed import get_comm_size_and_rank, is_model_distributed
from collections import OrderedDict


def get_model_or_module(model):
Expand Down Expand Up @@ -50,6 +51,14 @@ def load_existing_model_config(model, config, path="./logs/"):
def load_existing_model(model, model_name, path="./logs/"):
path_name = os.path.join(path, model_name, model_name + ".pk")
state_dict = torch.load(path_name, map_location="cpu")

if is_model_distributed(model):
ddp_state_dict = OrderedDict()
for k, v in state_dict.items():
k = "module." + k
ddp_state_dict[k] = v
state_dict = ddp_state_dict

model.load_state_dict(state_dict)


Expand Down

0 comments on commit dffb15c

Please sign in to comment.