
Fixing model checkpoints to be robust to multi -> single GPU usage #1091

Closed
wants to merge 7 commits
Changes from 5 commits
4 changes: 2 additions & 2 deletions jiant/__main__.py
@@ -506,7 +506,7 @@ def load_model_for_target_train_run(args, ckpt_path, model, strict, task, cuda_d
to_train: List of tuples of (name, weight) of trainable parameters

"""
load_model_state(model, ckpt_path, cuda_devices, skip_task_models=[task.name], strict=strict)
load_model_state(model, ckpt_path, skip_task_models=[task.name], strict=strict)
if args.transfer_paradigm == "finetune":
# Train both the task specific models as well as sentence encoder.
to_train = [(n, p) for n, p in model.named_parameters() if p.requires_grad]
@@ -675,7 +675,7 @@ def main(cl_arguments):
task_to_use = task_params(task.name).get("use_classifier", task.name)
ckpt_path = get_best_checkpoint_path(args, "eval", task_to_use)
assert ckpt_path is not None
load_model_state(model, ckpt_path, cuda_device, skip_task_models=[], strict=strict)
load_model_state(model, ckpt_path, skip_task_models=[], strict=strict)
evaluate_and_write(args, model, [task], splits_to_write, cuda_device)

if args.delete_checkpoints_when_done and not args.keep_all_checkpoints:
12 changes: 2 additions & 10 deletions jiant/trainer.py
@@ -33,6 +33,7 @@
get_output_attribute,
get_model_attribute,
uses_cuda,
load_model_state,
) # pylint: disable=import-error
from allennlp.nn.util import move_to_device

@@ -1118,7 +1119,6 @@ def _save_checkpoint(self, training_state, phase="pretrain", new_best=False, tas
task_dir_name,
"model_state_{}_val_{}{}.th".format(phase, val_pass, best_str),
)

model_state = self._model.state_dict()

# Skip non-trainable params, like the main ELMo params.
@@ -1217,15 +1217,7 @@ def _restore_checkpoint(self, phase, tasks=None):
self._serialization_dir, task_directory, "_".join(["metric", suffix])
)

model_state = torch.load(model_path)

for name, param in self._model.named_parameters():
if param.requires_grad and name not in model_state:
log.error("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
log.error("Parameter missing from checkpoint: " + name)
log.error("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")

self._model.load_state_dict(model_state, strict=False)
Contributor:
You're no longer setting strict=False here. It's debatable whether that's the ideal behavior here, but it was intentional, and I believe it has had some real experimental uses. Why the change?
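(For background, this is standard PyTorch load_state_dict behavior, not jiant-specific: strict=True raises on any key mismatch, while strict=False loads what it can, reports the mismatches, and leaves missing parameters at their initialized values. A minimal sketch:)

import torch
import torch.nn as nn

model = nn.Linear(4, 2)
partial_state = {"weight": torch.zeros(2, 4)}  # no "bias" entry

result = model.load_state_dict(partial_state, strict=False)  # does not raise
print(result.missing_keys)     # ['bias'] -- stays randomly initialized
print(result.unexpected_keys)  # []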

load_model_state(self._model, model_path)
Collaborator:
Is this warning being disabled?

Contributor Author:
No, it's not. The warning is also inside the load_model_state function.
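(Concretely, load_model_state performs the same check as the loop deleted from trainer.py above — a sketch based on those removed lines, with log assumed to be the module's logger:)

for name, param in model.named_parameters():
    if param.requires_grad and name not in model_state:
        log.error("Parameter missing from checkpoint: " + name)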

task_states = torch.load(task_state_path)
for task_name, task_state in task_states.items():
if task_name == "global":
Expand Down
34 changes: 31 additions & 3 deletions jiant/utils/utils.py
@@ -10,7 +10,9 @@
import glob
import torch
import jsondiff
import collections

import torch.nn as nn
from allennlp.common.checks import ConfigurationError
from allennlp.common.params import Params
from sacremoses import MosesDetokenizer
@@ -323,22 +325,48 @@ def select_task_specific_args(exp_args, diff_args):
return diff_args


def load_model_state(model, state_path, gpu_id, skip_task_models=[], strict=True):
def get_state_dict_for_loading(model, model_state) -> collections.OrderedDict:
""" Function for making sure state dict keys are named appropriately for
multi-GPU and single-GPU use cases

Parameters
----------
model: The model object to populate with loaded parameters.
model_state: collections.OrderedDict of model state
"""
final_model_state = collections.OrderedDict()

def get_key(name):
Contributor:
Add a comment explaining why we need this logic.

key = name
if isinstance(model, nn.DataParallel):
if "module" not in key:
key = "module.%s" % key
else:
if "module" in key:
key = key.replace("module.", "")
Collaborator:
Make this check explicitly for prefix with .startswith, and drop the first n characters (in case module appears somewhere else in the parameter name).
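(A minimal sketch of the suggested prefix handling — a hypothetical rewrite of get_key, not code from this PR:)

def get_key(name):
    # DataParallel state dicts prefix every key with "module."; add or strip
    # that prefix to match the model being loaded into. Matching the prefix
    # explicitly avoids mangling a parameter whose name merely contains the
    # substring "module".
    prefix = "module."
    if isinstance(model, nn.DataParallel):
        return name if name.startswith(prefix) else prefix + name
    return name[len(prefix):] if name.startswith(prefix) else name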

return key

for name, weights in model_state.items():
key = get_key(name)
final_model_state[key] = model_state[name]
Collaborator:
= weights
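(i.e., with the suggestion applied the loop would read:)

for name, weights in model_state.items():
    final_model_state[get_key(name)] = weights  # use the unpacked value directly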

return final_model_state
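
(To see the remapping end to end, a small usage sketch — illustrative models, assuming the helper above is importable from jiant.utils.utils:)

import torch.nn as nn
from jiant.utils.utils import get_state_dict_for_loading

single = nn.Linear(4, 2)                    # plain single-GPU model
wrapped = nn.DataParallel(nn.Linear(4, 2))  # multi-GPU wrapper

# DataParallel keys look like "module.weight"; after remapping they match
# the unwrapped model's "weight"/"bias" and load without key errors.
state = wrapped.state_dict()
single.load_state_dict(get_state_dict_for_loading(single, state))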


def load_model_state(model, state_path, skip_task_models=[], strict=True):
""" Helper function to load a model state

Parameters
----------
model: The model object to populate with loaded parameters.
state_path: The path to a model_state checkpoint.
gpu_id: The GPU to use. -1 for no GPU.
skip_task_models: If set, skip task-specific parameters for these tasks.
This does not necessarily skip loading ELMo scalar weights, but I (Sam) sincerely
doubt that this matters.
strict: Whether we should fail if any parameters aren't found in the checkpoint. If false,
there is a risk of leaving some parameters in their randomly initialized state.
"""
model_state = torch.load(state_path)

model_state = get_state_dict_for_loading(model, model_state)
assert_for_log(
not (skip_task_models and strict),
"Can't skip task models while also strictly loading task models. Something is wrong.",