From c8aabb5c327b33790d81c2a238eed2e95b1b463b Mon Sep 17 00:00:00 2001 From: Pei Zhang Date: Mon, 29 Nov 2021 17:13:30 -0500 Subject: [PATCH] model.load_state_dict update --- hydragnn/utils/model.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/hydragnn/utils/model.py b/hydragnn/utils/model.py index aff95740b..ed50ff700 100644 --- a/hydragnn/utils/model.py +++ b/hydragnn/utils/model.py @@ -17,6 +17,7 @@ from torch_geometric.utils import degree from hydragnn.utils.distributed import get_comm_size_and_rank, is_model_distributed +from collections import OrderedDict def get_model_or_module(model): @@ -50,6 +51,14 @@ def load_existing_model_config(model, config, path="./logs/"): def load_existing_model(model, model_name, path="./logs/"): path_name = os.path.join(path, model_name, model_name + ".pk") state_dict = torch.load(path_name, map_location="cpu") + + if is_model_distributed(model): + ddp_state_dict = OrderedDict() + for k, v in state_dict.items(): + k = "module." + k + ddp_state_dict[k] = v + state_dict = ddp_state_dict + model.load_state_dict(state_dict)