Load individual elements if state dict load fails #5213

Merged
merged 22 commits into from
Apr 6, 2021
Changes from 2 commits
Commits
49a8a23
load individual elements if state dict load fails
andrewcoh Apr 1, 2021
35980f1
clean up exception catching
andrewcoh Apr 1, 2021
38479f6
use debug statements
andrewcoh Apr 1, 2021
88eff1a
use load_state_dict strict = False
andrewcoh Apr 1, 2021
d67f11f
add unexpected keys
andrewcoh Apr 1, 2021
f6afb5f
fix typo in warning for unexpected keys
andrewcoh Apr 1, 2021
a79718e
add load different reward tests
andrewcoh Apr 1, 2021
7aa6ae4
add debug with error print out
andrewcoh Apr 5, 2021
a663ffa
add doc change
andrewcoh Apr 5, 2021
0230f2c
test convolutions can be loaded properly
andrewcoh Apr 5, 2021
4a57d60
add check that layers still have different dimensions
andrewcoh Apr 5, 2021
46891ec
add special case for non nn.Module load and comment
andrewcoh Apr 5, 2021
d07ff7c
Update ml-agents/mlagents/trainers/tests/torch/saver/test_saver.py
andrewcoh Apr 5, 2021
b8bae51
Update ml-agents/mlagents/trainers/tests/torch/saver/test_saver.py
andrewcoh Apr 5, 2021
ca8dda8
Update ml-agents/mlagents/trainers/model_saver/torch_model_saver.py
andrewcoh Apr 5, 2021
5a5abaa
Update ml-agents/mlagents/trainers/model_saver/torch_model_saver.py
andrewcoh Apr 5, 2021
e6b87d5
Update docs/Training-ML-Agents.md
andrewcoh Apr 5, 2021
2a20494
Update ml-agents/mlagents/trainers/model_saver/torch_model_saver.py
andrewcoh Apr 5, 2021
5a3e598
update changelog
andrewcoh Apr 5, 2021
cccd32d
Merge branch 'fix-resume-imi' of https://github.com/Unity-Technologie…
andrewcoh Apr 5, 2021
ac5e0e4
update changelog comment
andrewcoh Apr 5, 2021
41c8793
fix typo
andrewcoh Apr 5, 2021
20 changes: 19 additions & 1 deletion ml-agents/mlagents/trainers/model_saver/torch_model_saver.py
@@ -89,7 +89,25 @@ def _load_model(
         policy = cast(TorchPolicy, policy)
 
         for name, mod in modules.items():
-            mod.load_state_dict(saved_state_dict[name])
+            try:
+                mod.load_state_dict(saved_state_dict[name])
+            except Exception:
+                if name in saved_state_dict:
+                    logger.warning(f"Failed to load directly for module {name}.")
+                    for mod_element in mod.state_dict():
+                        try:
+                            logger.warning(f"Copying {mod_element}")
+                            mod.state_dict()[mod_element].copy_(
+                                saved_state_dict[name][mod_element]
+                            )
+                        except Exception:
+                            logger.warning(
+                                f"{mod_element} was not found or is not loadable (changed shape). Initializing."
+                            )
+                else:
+                    logger.warning(
+                        f"The module {name} was not found in the checkpoint. Initializing."
+                    )
 
         if reset_global_steps:
             policy.set_step(0)
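
For readers following the control flow of this hunk, here is a minimal, torch-free sketch of the same idea: attempt a strict whole-dictionary load, and if that fails, copy each element individually while leaving missing or reshaped elements at their initial values. It is an illustration only, not the code in this PR. Plain Python lists stand in for tensors, list length stands in for tensor shape, and the names `StateDict`, `strict_load`, and `load_with_fallback` are hypothetical.

```python
from typing import Dict, List

# Hypothetical stand-in for a tensor state dict: parameter name -> flat list of values.
StateDict = Dict[str, List[float]]


def strict_load(target: StateDict, saved: StateDict) -> None:
    """Mimic a strict load: fail unless keys and lengths match exactly."""
    if target.keys() != saved.keys():
        raise KeyError("key mismatch between module and checkpoint")
    # Validate every element before copying anything, so a failure leaves target untouched.
    for key, values in saved.items():
        if len(values) != len(target[key]):
            raise ValueError(f"size mismatch for {key}")
    for key, values in saved.items():
        target[key][:] = values  # in-place copy, analogous to Tensor.copy_


def load_with_fallback(target: StateDict, saved: StateDict) -> List[str]:
    """Try a strict load; on failure, copy element by element.

    Returns the names of elements that were left at their initial values.
    """
    try:
        strict_load(target, saved)
        return []
    except (KeyError, ValueError):
        pass  # fall through to the per-element copy

    skipped = []
    for key, current in target.items():
        values = saved.get(key)
        if values is None or len(values) != len(current):
            skipped.append(key)  # missing or reshaped: keep the initialization
            continue
        current[:] = values
    return skipped


# A checkpoint whose reward head changed size between runs.
module = {"encoder.weight": [0.0, 0.0], "reward.weight": [0.0, 0.0, 0.0]}
checkpoint = {"encoder.weight": [1.0, 2.0], "reward.weight": [9.0]}
skipped = load_with_fallback(module, checkpoint)
print(skipped)
print(module)
```

In this sketch the encoder values are restored from the checkpoint while the resized reward element keeps its initial values, which mirrors the "Initializing" warning path in the hunk above. Returning the skipped names rather than only logging them is a design choice made here so that a caller or test can check which elements were not restored.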