diff --git a/jiant/__main__.py b/jiant/__main__.py index c7a7cd87e..c3fc8daac 100644 --- a/jiant/__main__.py +++ b/jiant/__main__.py @@ -506,7 +506,7 @@ def load_model_for_target_train_run(args, ckpt_path, model, strict, task, cuda_d to_train: List of tuples of (name, weight) of trainable parameters """ - load_model_state(model, ckpt_path, cuda_devices, skip_task_models=[task.name], strict=strict) + load_model_state(model, ckpt_path, skip_task_models=[task.name], strict=strict) if args.transfer_paradigm == "finetune": # Train both the task specific models as well as sentence encoder. to_train = [(n, p) for n, p in model.named_parameters() if p.requires_grad] @@ -675,7 +675,7 @@ def main(cl_arguments): task_to_use = task_params(task.name).get("use_classifier", task.name) ckpt_path = get_best_checkpoint_path(args, "eval", task_to_use) assert ckpt_path is not None - load_model_state(model, ckpt_path, cuda_device, skip_task_models=[], strict=strict) + load_model_state(model, ckpt_path, skip_task_models=[], strict=strict) evaluate_and_write(args, model, [task], splits_to_write, cuda_device) if args.delete_checkpoints_when_done and not args.keep_all_checkpoints: diff --git a/jiant/trainer.py b/jiant/trainer.py index c50ff7f7d..c9cf7d124 100644 --- a/jiant/trainer.py +++ b/jiant/trainer.py @@ -33,6 +33,7 @@ get_output_attribute, get_model_attribute, uses_cuda, + load_model_state, ) # pylint: disable=import-error from allennlp.nn.util import move_to_device @@ -1118,7 +1119,6 @@ def _save_checkpoint(self, training_state, phase="pretrain", new_best=False, tas task_dir_name, "model_state_{}_val_{}{}.th".format(phase, val_pass, best_str), ) - model_state = self._model.state_dict() # Skip non-trainable params, like the main ELMo params. @@ -1217,15 +1217,7 @@ def _restore_checkpoint(self, phase, tasks=None): self._serialization_dir, task_directory, "_".join(["metric", suffix]) ) - model_state = torch.load(model_path) - - for name, param in self._model.named_parameters(): - if param.requires_grad and name not in model_state: - log.error("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!") - log.error("Parameter missing from checkpoint: " + name) - log.error("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!") - - self._model.load_state_dict(model_state, strict=False) + load_model_state(self._model, model_path) task_states = torch.load(task_state_path) for task_name, task_state in task_states.items(): if task_name == "global": diff --git a/jiant/utils/utils.py b/jiant/utils/utils.py index 3dca1174c..a456e1fa1 100644 --- a/jiant/utils/utils.py +++ b/jiant/utils/utils.py @@ -10,7 +10,9 @@ import glob import torch import jsondiff +import collections +import torch.nn as nn from allennlp.common.checks import ConfigurationError from allennlp.common.params import Params from sacremoses import MosesDetokenizer @@ -323,14 +325,41 @@ def select_task_specific_args(exp_args, diff_args): return diff_args -def load_model_state(model, state_path, gpu_id, skip_task_models=[], strict=True): +def get_state_dict_for_loading(model, model_state) -> nn.Module: + """ Function for making sure state dict keys are named appropriately for + multi-GPU and single-GPU use cases + + Parameters + ---------- + model: The model object to populate with loaded parameters. + model_state: collections.OrderdDict of model state + """ + final_model_state = collections.OrderedDict() + + def get_key(name): + key = name + if isinstance(model, nn.DataParallel): + if "module" not in key: + key = "module.%s" % key + else: + if key.startswith("module."): + # Drop the first 7 characters, which is the prefix "module." + key = key[7:] + return key + + for name, weights in model_state.items(): + key = get_key(name) + final_model_state[key] = weights + return final_model_state + + +def load_model_state(model, state_path, skip_task_models=[], strict=True): """ Helper function to load a model state Parameters ---------- model: The model object to populate with loaded parameters. state_path: The path to a model_state checkpoint. - gpu_id: The GPU to use. -1 for no GPU. skip_task_models: If set, skip task-specific parameters for these tasks. This does not necessarily skip loading ELMo scalar weights, but I (Sam) sincerely doubt that this matters. @@ -338,7 +367,7 @@ def load_model_state(model, state_path, gpu_id, skip_task_models=[], strict=True there is a risk of leaving some parameters in their randomly initialized state. """ model_state = torch.load(state_path) - + model_state = get_state_dict_for_loading(model, model_state) assert_for_log( not (skip_task_models and strict), "Can't skip task models while also strictly loading task models. Something is wrong.",