From d2016d4c24bf6d775a13c089680c020019932108 Mon Sep 17 00:00:00 2001
From: pruksmhc
Date: Sat, 16 May 2020 19:59:06 -0700
Subject: [PATCH 1/5] Adding get_model_state

---
 jiant/models.py  | 6 ++++++
 jiant/trainer.py | 2 +-
 2 files changed, 7 insertions(+), 1 deletion(-)

diff --git a/jiant/models.py b/jiant/models.py
index 53d6e51b2..20f93ec50 100644
--- a/jiant/models.py
+++ b/jiant/models.py
@@ -1190,6 +1190,12 @@ def _mc_forward(self, batch, task, predict):
         out["preds"] = logits.argmax(dim=-1)
         return out

+    def get_state_dict_for_saving(self) -> nn.Module:
+        if isinstance(self, nn.DataParallel):
+            return self.module.state_dict()
+        else:
+            return self.state_dict()
+
     def _lm_only_lr_forward(self, batch, task):
         """Only left to right pass for LM model - non-bidirectional models.
            Used for language modeling training only in one direction.
diff --git a/jiant/trainer.py b/jiant/trainer.py
index c3cb1f6a9..7690c576f 100644
--- a/jiant/trainer.py
+++ b/jiant/trainer.py
@@ -1107,7 +1107,7 @@ def _save_checkpoint(self, training_state, phase="pretrain", new_best=False, tas
             "model_state_{}_val_{}{}.th".format(phase, val_pass, best_str),
         )

-        model_state = self._model.state_dict()
+        model_state = self._model.get_state_dict_for_saving()

         # Skip non-trainable params, like the main ELMo params.
         for name, param in self._model.named_parameters():
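
A note on why the patch above unwraps the model before saving: nn.DataParallel registers the wrapped network under the attribute "module", so every key in the wrapper's state dict carries a "module." prefix that a bare single-GPU model will not recognize. The sketch below demonstrates this with a toy nn.Linear standing in for the jiant model; it is illustrative only, not part of the patch.

    import torch.nn as nn

    model = nn.Linear(4, 2)
    wrapped = nn.DataParallel(model)

    print(list(model.state_dict().keys()))    # ['weight', 'bias']
    print(list(wrapped.state_dict().keys()))  # ['module.weight', 'module.bias']

    # Saving the unwrapped module, as get_state_dict_for_saving does,
    # keeps checkpoint keys loadable without the DataParallel wrapper:
    state = wrapped.module.state_dict() if isinstance(wrapped, nn.DataParallel) else wrapped.state_dict()
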
From de0c4a93a7daa02b1a420917a84d7a84b95b374f Mon Sep 17 00:00:00 2001
From: Yada Pruksachatkun
Date: Sun, 17 May 2020 13:29:37 -0400
Subject: [PATCH 2/5] deleting unused parameter and fixing model loading

---
 jiant/__main__.py    |  6 +++---
 jiant/trainer.py     | 14 +++-----------
 jiant/utils/utils.py | 26 +++++++++++++++++++++++---
 3 files changed, 29 insertions(+), 17 deletions(-)

diff --git a/jiant/__main__.py b/jiant/__main__.py
index f20de7762..1a2d028f0 100644
--- a/jiant/__main__.py
+++ b/jiant/__main__.py
@@ -506,7 +506,7 @@ def load_model_for_target_train_run(args, ckpt_path, model, strict, task, cuda_d
         to_train: List of tuples of (name, weight) of trainable parameters
     """

-    load_model_state(model, ckpt_path, cuda_devices, skip_task_models=[task.name], strict=strict)
+    load_model_state(model, ckpt_path, skip_task_models=[task.name], strict=strict)
     if args.transfer_paradigm == "finetune":
         # Train both the task specific models as well as sentence encoder.
         to_train = [(n, p) for n, p in model.named_parameters() if p.requires_grad]
@@ -607,11 +607,11 @@ def main(cl_arguments):
         # including and following that task.
         last_task_index = [task.name for task in target_tasks_to_train].index(task_to_restore)
         target_tasks_to_train = target_tasks_to_train[last_task_index:]
+
     for task in target_tasks_to_train:
         # Skip tasks that should not be trained on.
         if task.eval_only_task:
             continue
-
         params_to_train = load_model_for_target_train_run(
             args, pre_target_train_path, model, strict, task, cuda_device
         )
@@ -649,7 +649,7 @@ def main(cl_arguments):
            task_to_use = task_params(task.name).get("use_classifier", task.name)
            ckpt_path = get_best_checkpoint_path(args, "eval", task_to_use)
            assert ckpt_path is not None
-           load_model_state(model, ckpt_path, cuda_device, skip_task_models=[], strict=strict)
+           load_model_state(model, ckpt_path, skip_task_models=[], strict=strict)
            evaluate_and_write(args, model, [task], splits_to_write, cuda_device)

     if args.delete_checkpoints_when_done and not args.keep_all_checkpoints:
diff --git a/jiant/trainer.py b/jiant/trainer.py
index 7690c576f..605944651 100644
--- a/jiant/trainer.py
+++ b/jiant/trainer.py
@@ -33,6 +33,7 @@
     get_output_attribute,
     get_model_attribute,
     uses_cuda,
+    load_model_state,
 )  # pylint: disable=import-error

 from allennlp.nn.util import move_to_device
@@ -1106,8 +1107,7 @@ def _save_checkpoint(self, training_state, phase="pretrain", new_best=False, tas
             task_dir_name,
             "model_state_{}_val_{}{}.th".format(phase, val_pass, best_str),
         )
-
-        model_state = self._model.get_state_dict_for_saving()
+        model_state = self._model.state_dict()

         # Skip non-trainable params, like the main ELMo params.
         for name, param in self._model.named_parameters():
@@ -1205,15 +1205,7 @@ def _restore_checkpoint(self, phase, tasks=None):
             self._serialization_dir, task_directory, "_".join(["metric", suffix])
         )

-        model_state = torch.load(model_path)
-
-        for name, param in self._model.named_parameters():
-            if param.requires_grad and name not in model_state:
-                log.error("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
-                log.error("Parameter missing from checkpoint: " + name)
-                log.error("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
-
-        self._model.load_state_dict(model_state, strict=False)
+        load_model_state(self._model, model_path)
         task_states = torch.load(task_state_path)
         for task_name, task_state in task_states.items():
             if task_name == "global":
diff --git a/jiant/utils/utils.py b/jiant/utils/utils.py
index 3dca1174c..2f48575d6 100644
--- a/jiant/utils/utils.py
+++ b/jiant/utils/utils.py
@@ -10,7 +10,9 @@
 import glob
 import torch
 import jsondiff
+import collections

+import torch.nn as nn
 from allennlp.common.checks import ConfigurationError
 from allennlp.common.params import Params
 from sacremoses import MosesDetokenizer
@@ -323,14 +325,32 @@ def select_task_specific_args(exp_args, diff_args):
     return diff_args


-def load_model_state(model, state_path, gpu_id, skip_task_models=[], strict=True):
+def get_state_dict_for_loading(model, model_state) -> nn.Module:
+    final_model_state = collections.OrderedDict()
+
+    def get_key(name):
+        key = name
+        if isinstance(model, nn.DataParallel):
+            if "module" not in key:
+                key = "module.%s" % key
+        else:
+            if "module" in key:
+                key = key.replace("module.", "")
+        return key
+
+    for name, weights in model_state.items():
+        key = get_key(name)
+        final_model_state[key] = model_state[name]
+    return final_model_state
+
+
+def load_model_state(model, state_path, skip_task_models=[], strict=True):
     """ Helper function to load a model state

     Parameters
     ----------
     model: The model object to populate with loaded parameters.
     state_path: The path to a model_state checkpoint.
-    gpu_id: The GPU to use. -1 for no GPU.
     skip_task_models: If set, skip task-specific parameters for these tasks.
         This does not necessarily skip loading ELMo scalar weights, but I (Sam)
         sincerely doubt that this matters.
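
For reference, the key normalization that get_state_dict_for_loading performs can be exercised in isolation. The sketch below is a condensed re-implementation under the same assumption the patch makes (checkpoint keys either all carry the "module." prefix or none do); it is illustrative, not the patch's code verbatim.

    import collections
    import torch.nn as nn

    def normalize_keys(model, model_state):
        # Add the "module." prefix when the live model is DataParallel-wrapped;
        # strip it when the model is a bare single-GPU module.
        wrap = isinstance(model, nn.DataParallel)
        out = collections.OrderedDict()
        for name, weights in model_state.items():
            if wrap and not name.startswith("module."):
                name = "module." + name
            elif not wrap and name.startswith("module."):
                name = name[len("module."):]
            out[name] = weights
        return out

    state = collections.OrderedDict([("module.weight", 0), ("module.bias", 1)])
    print(list(normalize_keys(nn.Linear(4, 2), state)))  # ['weight', 'bias']
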
@@ -338,7 +358,7 @@ def load_model_state(model, state_path, gpu_id, skip_task_models=[], strict=True
         there is a risk of leaving some parameters in their randomly initialized state.
     """
     model_state = torch.load(state_path)
-
+    model_state = get_state_dict_for_loading(model, model_state)
     assert_for_log(
         not (skip_task_models and strict),
         "Can't skip task models while also strictly loading task models. Something is wrong.",

From 5c13ae5491e1c213877999e40d5b6dcedf3cf75e Mon Sep 17 00:00:00 2001
From: Yada Pruksachatkun
Date: Sun, 17 May 2020 13:38:22 -0400
Subject: [PATCH 3/5] Adding documentation

---
 jiant/__main__.py    | 2 +-
 jiant/models.py      | 6 ------
 jiant/utils/utils.py | 8 ++++++++
 3 files changed, 9 insertions(+), 7 deletions(-)

diff --git a/jiant/__main__.py b/jiant/__main__.py
index 1a2d028f0..cb3a0e952 100644
--- a/jiant/__main__.py
+++ b/jiant/__main__.py
@@ -607,11 +607,11 @@ def main(cl_arguments):
         # including and following that task.
         last_task_index = [task.name for task in target_tasks_to_train].index(task_to_restore)
         target_tasks_to_train = target_tasks_to_train[last_task_index:]
-
     for task in target_tasks_to_train:
         # Skip tasks that should not be trained on.
         if task.eval_only_task:
             continue
+
         params_to_train = load_model_for_target_train_run(
             args, pre_target_train_path, model, strict, task, cuda_device
         )
diff --git a/jiant/models.py b/jiant/models.py
index 20f93ec50..53d6e51b2 100644
--- a/jiant/models.py
+++ b/jiant/models.py
@@ -1190,12 +1190,6 @@ def _mc_forward(self, batch, task, predict):
         out["preds"] = logits.argmax(dim=-1)
         return out

-    def get_state_dict_for_saving(self) -> nn.Module:
-        if isinstance(self, nn.DataParallel):
-            return self.module.state_dict()
-        else:
-            return self.state_dict()
-
     def _lm_only_lr_forward(self, batch, task):
         """Only left to right pass for LM model - non-bidirectional models.
            Used for language modeling training only in one direction.
diff --git a/jiant/utils/utils.py b/jiant/utils/utils.py
index 2f48575d6..e7e109da9 100644
--- a/jiant/utils/utils.py
+++ b/jiant/utils/utils.py
@@ -326,6 +326,14 @@ def select_task_specific_args(exp_args, diff_args):


 def get_state_dict_for_loading(model, model_state) -> nn.Module:
+    """ Function for making sure state dict keys are named appropriately for
+    multi-GPU and single-GPU use cases
+
+    Parameters
+    ----------
+    model: The model object to populate with loaded parameters.
+    model_state: collections.OrderedDict of model state
+    """
     final_model_state = collections.OrderedDict()

     def get_key(name):

From c98a1b7fd9e580426a8d42c846593c8d17384fc4 Mon Sep 17 00:00:00 2001
From: pruksmhc
Date: Mon, 18 May 2020 13:38:32 -0700
Subject: [PATCH 4/5] Cleaning up code

---
 jiant/utils/utils.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/jiant/utils/utils.py b/jiant/utils/utils.py
index e7e109da9..24ee2c330 100644
--- a/jiant/utils/utils.py
+++ b/jiant/utils/utils.py
@@ -342,13 +342,14 @@ def get_key(name):
             if "module" not in key:
                 key = "module.%s" % key
         else:
-            if "module" in key:
-                key = key.replace("module.", "")
+            if key.startswith("module."):
+                # Drop the first 7 characters, which is the prefix "module."
+                key = key[7:]
         return key

     for name, weights in model_state.items():
         key = get_key(name)
-        final_model_state[key] = model_state[name]
+        final_model_state[key] = weight
     return final_model_state

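The switch above from str.replace to startswith plus slicing is not just cosmetic: replace("module.", "") removes the substring everywhere it occurs, which would corrupt any parameter name that legitimately contains "module." beyond the leading prefix, while slicing off the first len("module.") == 7 characters touches only the prefix. A quick illustration (the key below is hypothetical):

    key = "module.sent_encoder.module.weight"
    print(key.replace("module.", ""))                      # 'sent_encoder.weight' -- both occurrences stripped
    print(key[7:] if key.startswith("module.") else key)   # 'sent_encoder.module.weight' -- prefix only
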
From ece69ba50c89026d22cdb17bfe5be0c0949a589b Mon Sep 17 00:00:00 2001
From: pruksmhc
Date: Mon, 18 May 2020 13:47:02 -0700
Subject: [PATCH 5/5] Fix typo

---
 jiant/utils/utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/jiant/utils/utils.py b/jiant/utils/utils.py
index 24ee2c330..a456e1fa1 100644
--- a/jiant/utils/utils.py
+++ b/jiant/utils/utils.py
@@ -349,7 +349,7 @@ def get_key(name):

     for name, weights in model_state.items():
         key = get_key(name)
-        final_model_state[key] = weight
+        final_model_state[key] = weights
     return final_model_state
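
Taken together, the series makes checkpoints portable between single-GPU and nn.DataParallel runs. A minimal end-to-end sketch of the round trip follows (toy module and illustrative path, not jiant's actual training loop):

    import torch
    import torch.nn as nn

    # Save from a DataParallel-wrapped model: keys come out as 'module.*'.
    wrapped = nn.DataParallel(nn.Linear(4, 2))
    torch.save(wrapped.state_dict(), "model_state.th")

    # Load into a bare model by stripping the prefix, as
    # get_state_dict_for_loading does for the non-DataParallel case.
    fresh = nn.Linear(4, 2)
    state = torch.load("model_state.th")
    state = {k[len("module."):] if k.startswith("module.") else k: v
             for k, v in state.items()}
    fresh.load_state_dict(state, strict=True)
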