Fixing model checkpoints to be robust to multi -> single GPU usage #1091
Conversation
Hello @pruksmhc! Thanks for updating this PR. We checked the lines you've touched for PEP 8 issues, and found:
You can repair most issues by installing black and running:
Comment last updated at 2020-05-18 20:46:53 UTC
jiant/utils/utils.py
Outdated
if "module" in key:
    key = key.replace("module.", "")
Make this check explicitly for prefix with .startswith, and drop the first n characters (in case module appears somewhere else in the parameter name).
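A minimal sketch of what the reviewer is suggesting (the helper name `get_key` comes from the diff below; the rest is illustrative, not the actual jiant code): strip `"module."` only when it is genuinely a prefix, so parameter names that happen to contain the substring elsewhere are left alone.

```python
# Illustrative sketch of the .startswith suggestion, not the actual jiant code.
PREFIX = "module."

def get_key(name):
    # Drop the DataParallel prefix only at the start of the key; a name like
    # "encoder.module_gate.weight" must not be rewritten.
    if name.startswith(PREFIX):
        return name[len(PREFIX):]
    return name
```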
jiant/utils/utils.py
Outdated
for name, weights in model_state.items():
    key = get_key(name)
    final_model_state[key] = model_state[name]
= weights (i.e., use the loop variable instead of re-indexing model_state[name])
log.error("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")

self._model.load_state_dict(model_state, strict=False)
load_model_state(self._model, model_path)
Is this warning being disabled?
No, it's not. The same warning is also emitted inside the load_model_state function.
Hi @zphang & @HaokunLiu — is either of you available to provide a substantive review for this PR? The core concerns seem to be 1) whether this addresses issue #1087, and 2) whether these changes introduce new risks/regressions.
"""
final_model_state = collections.OrderedDict()

def get_key(name):
Add a comment explaining why we need this logic.
log.error("Parameter missing from checkpoint: " + name)
log.error("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")

self._model.load_state_dict(model_state, strict=False)
You're no longer setting strict=False here. It's debatable whether that's the ideal behavior here, but it was intentional, and I believe it has had some real experimental uses. Why the change?
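For context on why dropping `strict=False` matters, here is a small demonstration of standard PyTorch `load_state_dict` behavior (not jiant-specific code): with the default `strict=True`, missing or unexpected keys raise a `RuntimeError`; with `strict=False`, PyTorch loads what it can and reports the mismatches in its return value instead.

```python
# Demonstration of load_state_dict's strict flag (standard PyTorch behavior).
import torch.nn as nn

model = nn.Linear(4, 2)
partial = {"weight": model.state_dict()["weight"]}  # "bias" is missing

# strict=True (the default) raises on missing/unexpected keys.
try:
    model.load_state_dict(partial)
except RuntimeError as e:
    print("strict load failed:", type(e).__name__)

# strict=False loads what it can and reports mismatches instead of raising.
result = model.load_state_dict(partial, strict=False)
print(result.missing_keys)  # ['bias']
```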
Are these changes still necessary? Planning to close all PRs to move jiant2 to this repo in the near future.
This is a fix for #1087. I decided to make the change in the model-loading code because making the change in model saving, as suggested in #1087, would fix multi -> single GPU model loading but would break multi -> multi GPU model loading (i.e., reloading a checkpoint that was trained on multi-GPU onto another multi-GPU machine).
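For readers unfamiliar with the underlying issue: wrapping a model in `torch.nn.DataParallel` registers the original model as a submodule named `module`, so every key in the saved `state_dict` gains a `module.` prefix, and such a checkpoint cannot be loaded directly into an unwrapped model. A minimal illustration (generic PyTorch, not the jiant code):

```python
# Why multi-GPU checkpoints need key renaming: DataParallel prefixes keys.
import torch.nn as nn

model = nn.Linear(4, 2)
wrapped = nn.DataParallel(model)

print(list(model.state_dict().keys()))    # ['weight', 'bias']
print(list(wrapped.state_dict().keys()))  # ['module.weight', 'module.bias']

# Loading-side fix: strip the prefix before load_state_dict on a bare model.
ckpt = wrapped.state_dict()
fixed = {k[len("module."):] if k.startswith("module.") else k: v
         for k, v in ckpt.items()}
model.load_state_dict(fixed)  # succeeds; keys now match the unwrapped model
```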
I also did some light cleanup of model loading in the trainer to remove redundancy, and deleted an unused parameter.
Tests
Multi -> Single GPU: I tested by training a roberta-large model on SST on multi-GPU, then loading that checkpoint on a single GPU for further training.
Multi -> Multi GPU: jiant already does this implicitly — we load the best checkpoint before running evaluation — so this path was exercised when I trained the roberta-large SST model the first time on multi-GPU.