Fixing model checkpoints to be robust to multi -> single GPU usage #1091
@@ -33,6 +33,7 @@
     get_output_attribute,
     get_model_attribute,
     uses_cuda,
+    load_model_state,
 )  # pylint: disable=import-error
 from allennlp.nn.util import move_to_device

@@ -1118,7 +1119,6 @@ def _save_checkpoint(self, training_state, phase="pretrain", new_best=False, tas
                 task_dir_name,
                 "model_state_{}_val_{}{}.th".format(phase, val_pass, best_str),
             )
-
             model_state = self._model.state_dict()

             # Skip non-trainable params, like the main ELMo params.
@@ -1217,15 +1217,7 @@ def _restore_checkpoint(self, phase, tasks=None):
             self._serialization_dir, task_directory, "_".join(["metric", suffix])
         )

-        model_state = torch.load(model_path)
-
-        for name, param in self._model.named_parameters():
-            if param.requires_grad and name not in model_state:
-                log.error("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
-                log.error("Parameter missing from checkpoint: " + name)
-                log.error("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
-
-        self._model.load_state_dict(model_state, strict=False)
+        load_model_state(self._model, model_path)
Review comment: Is this warning being disabled?

Reply: No, it's not. This warning is also emitted inside the load_model_state function.
         task_states = torch.load(task_state_path)
         for task_name, task_state in task_states.items():
             if task_name == "global":
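For context on the bug this PR fixes: nn.DataParallel keeps the wrapped network under a module attribute, so every key in its state_dict() carries a "module." prefix, and a checkpoint written from a multi-GPU run no longer matches the parameter names of a single-GPU model. A minimal sketch of the mismatch (the toy model below is made up purely for illustration):

```python
import torch.nn as nn

# A throwaway stand-in model; any nn.Module shows the same behavior.
net = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2))
print(list(net.state_dict()))      # ['0.weight', '0.bias', '1.weight', '1.bias']

# DataParallel stores the same network under `self.module`, so every
# state dict key gains a "module." prefix.
wrapped = nn.DataParallel(net)
print(list(wrapped.state_dict()))  # ['module.0.weight', 'module.0.bias', ...]

# Strict loading of a multi-GPU checkpoint into an unwrapped model
# therefore fails with missing/unexpected key errors.
try:
    net.load_state_dict(wrapped.state_dict())
except RuntimeError as err:
    print(err)  # Missing key(s): "0.weight", ... Unexpected key(s): "module.0.weight", ...
```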
@@ -10,7 +10,9 @@
 import glob
 import torch
 import jsondiff
+import collections

+import torch.nn as nn
 from allennlp.common.checks import ConfigurationError
 from allennlp.common.params import Params
 from sacremoses import MosesDetokenizer
@@ -323,22 +325,48 @@ def select_task_specific_args(exp_args, diff_args):
     return diff_args


-def load_model_state(model, state_path, gpu_id, skip_task_models=[], strict=True):
+def get_state_dict_for_loading(model, model_state) -> collections.OrderedDict:
+    """ Function for making sure state dict keys are named appropriately for
+    multi-GPU and single-GPU use cases.
+
+    Parameters
+    ----------
+    model: The model object to populate with loaded parameters.
+    model_state: collections.OrderedDict of model state
+    """
+    final_model_state = collections.OrderedDict()
+
+    def get_key(name):
Review comment: Add a comment explaining why we need this logic.
+        key = name
+        if isinstance(model, nn.DataParallel):
+            if "module" not in key:
+                key = "module.%s" % key
+        else:
+            if "module" in key:
+                key = key.replace("module.", "")
Review comment: Make this check explicitly test for the prefix, e.g. with startswith.
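A hypothetical sketch of what the reviewer seems to be asking for (not necessarily the code that was merged): test for the prefix explicitly with str.startswith rather than substring containment, so that a parameter whose name merely contains "module" is not mangled.

```python
# Hypothetical variant of get_key() using explicit prefix tests.
PREFIX = "module."

def get_key(name, model_is_data_parallel):
    if model_is_data_parallel:
        # Add the prefix only if it is not already there.
        return name if name.startswith(PREFIX) else PREFIX + name
    # Strip the prefix only when it is genuinely a prefix, so a name
    # like "text_module.weight" is left untouched.
    return name[len(PREFIX):] if name.startswith(PREFIX) else name
```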
+        return key
+
+    for name, weights in model_state.items():
+        key = get_key(name)
+        final_model_state[key] = model_state[name]
+
+    return final_model_state
+
+
+def load_model_state(model, state_path, skip_task_models=[], strict=True):
     """ Helper function to load a model state

     Parameters
     ----------
     model: The model object to populate with loaded parameters.
     state_path: The path to a model_state checkpoint.
-    gpu_id: The GPU to use. -1 for no GPU.
     skip_task_models: If set, skip task-specific parameters for these tasks.
         This does not necessarily skip loading ELMo scalar weights, but I (Sam) sincerely
         doubt that this matters.
     strict: Whether we should fail if any parameters aren't found in the checkpoint. If false,
         there is a risk of leaving some parameters in their randomly initialized state.
     """
     model_state = torch.load(state_path)

+    model_state = get_state_dict_for_loading(model, model_state)
     assert_for_log(
         not (skip_task_models and strict),
         "Can't skip task models while also strictly loading task models. Something is wrong.",
Review comment: You're no longer setting strict=False here. It's debatable whether that's the ideal behavior, but it was intentional, and I believe it has had some real experimental uses. Why the change?
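For background on the reviewer's concern: PyTorch's Module.load_state_dict raises under strict=True when any key is missing or unexpected, while strict=False merely reports the mismatches and leaves absent parameters at their initial (random) values. A minimal illustration:

```python
import torch.nn as nn

model = nn.Linear(4, 2)
partial_state = {"weight": model.state_dict()["weight"]}  # no "bias" entry

# strict=False: returns the mismatches instead of raising.
result = model.load_state_dict(partial_state, strict=False)
print(result.missing_keys)     # ['bias'] -- left at its initial value
print(result.unexpected_keys)  # []

# strict=True (the default): raises RuntimeError on the missing key.
try:
    model.load_state_dict(partial_state, strict=True)
except RuntimeError as err:
    print("strict load failed:", err)
```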