Load individual elements if state dict load fails (#5213)
Co-authored-by: Vincent-Pierre BERGES <vincentpierre@unity3d.com>
Co-authored-by: Ervin T. <ervin@unity3d.com>
3 people authored Apr 6, 2021
1 parent 30fde2d commit ac4f43c
Showing 5 changed files with 147 additions and 3 deletions.
4 changes: 4 additions & 0 deletions com.unity.ml-agents/CHANGELOG.md
@@ -31,6 +31,10 @@ sizes and will need to be retrained. (#5181)
different sizes using the same model. For a summary of the interface changes, please see the Migration Guide. (#5189)

#### ml-agents / ml-agents-envs / gym-unity (Python)
- The `--resume` flag now supports resuming experiments with additional reward providers or
loading partial models if the network architecture has changed. See
[here](https://github.com/Unity-Technologies/ml-agents/blob/main/docs/Training-ML-Agents.md#loading-an-existing-model)
for more details. (#5213)

### Minor Changes
#### com.unity.ml-agents / com.unity.ml-agents.extensions (C#)
7 changes: 7 additions & 0 deletions docs/Training-ML-Agents.md
@@ -117,6 +117,13 @@ Python by using both the `--resume` and `--inference` flags. Note that if you
want to run inference in Unity, you should use the
[Unity Inference Engine](Getting-Started.md#running-a-pre-trained-model).

Additionally, if the network architecture changes, you may still load an existing model,
but ML-Agents will load only the parts of the model it can and initialize the rest from
scratch. For instance, if you add a new reward signal, the existing model will load but
the new reward signal will be initialized from scratch. If you have a model with a visual
encoder (CNN) but change the `hidden_units`, the CNN will be loaded but the body of the
network will be initialized from scratch.

Alternatively, you might want to start a new training run but _initialize_ it
using an already-trained model. You may want to do this, for instance, if your
environment changed and you want a new model, but the old behavior is still
29 changes: 28 additions & 1 deletion ml-agents/mlagents/trainers/model_saver/torch_model_saver.py
@@ -89,7 +89,34 @@ def _load_model(
policy = cast(TorchPolicy, policy)

for name, mod in modules.items():
mod.load_state_dict(saved_state_dict[name])
try:
if isinstance(mod, torch.nn.Module):
missing_keys, unexpected_keys = mod.load_state_dict(
saved_state_dict[name], strict=False
)
if missing_keys:
logger.warning(
f"Did not find these keys {missing_keys} in checkpoint. Initializing."
)
if unexpected_keys:
logger.warning(
f"Did not expect these keys {unexpected_keys} in checkpoint. Ignoring."
)
else:
# If module is not an nn.Module, try to load as one piece
mod.load_state_dict(saved_state_dict[name])

# KeyError is raised if the module was not present in the last run but is being
# accessed in the saved_state_dict.
# ValueError is raised by the optimizer's load_state_dict if the parameters
# have changed. Note, the optimizer uses a completely different load_state_dict
# function because it is not an nn.Module.
# RuntimeError is raised by PyTorch if there is a size mismatch between modules
# of the same name. This will still partially assign values to those layers that
# have not changed shape.
except (KeyError, ValueError, RuntimeError) as err:
logger.warning(f"Failed to load for module {name}. Initializing")
logger.debug(f"Module loading error: {err}")

if reset_global_steps:
policy.set_step(0)
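
For context on the comments in the hunk above, here is a standalone PyTorch sketch (toy
layer shapes and a hypothetical `make_net` helper, not part of this commit) showing that a
shape mismatch raises `RuntimeError` only after the parameters whose shapes still match
have already been copied, which is why catching the error still leaves the unchanged
layers loaded:

```python
import torch
from torch import nn


def make_net(hidden_units: int) -> nn.Module:
    # Toy stand-in for an actor: a conv "visual encoder" plus a fully connected body.
    return nn.Sequential(
        nn.Conv2d(3, 8, kernel_size=3),        # same shape in both networks
        nn.Flatten(),
        nn.Linear(8 * 30 * 30, hidden_units),  # shape changes with hidden_units
    )


old_net = make_net(hidden_units=12)
new_net = make_net(hidden_units=10)

try:
    new_net.load_state_dict(old_net.state_dict(), strict=False)
except RuntimeError as err:
    # Raised for the mismatched Linear weights, but the matching conv
    # parameters were already copied in place before the error.
    print(f"Partial load: {err}")

assert torch.equal(new_net[0].weight, old_net[0].weight)  # conv weights copied
assert new_net[2].weight.shape[0] == 10                   # body keeps its new size
```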
46 changes: 46 additions & 0 deletions ml-agents/mlagents/trainers/tests/torch/saver/test_saver.py
@@ -12,6 +12,8 @@
from mlagents.trainers.model_saver.torch_model_saver import TorchModelSaver
from mlagents.trainers.settings import (
TrainerSettings,
NetworkSettings,
EncoderType,
PPOSettings,
SACSettings,
POCASettings,
@@ -70,6 +72,50 @@ def test_load_save_policy(tmp_path):
assert policy3.get_current_step() == 0


@pytest.mark.parametrize("vis_encode_type", ["resnet", "nature_cnn", "match3"])
def test_load_policy_different_hidden_units(tmp_path, vis_encode_type):
path1 = os.path.join(tmp_path, "runid1")
trainer_params = TrainerSettings()
trainer_params.network_settings = NetworkSettings(
hidden_units=12, vis_encode_type=EncoderType(vis_encode_type)
)
policy = create_policy_mock(trainer_params, use_visual=True)
conv_params = [mod for mod in policy.actor.parameters() if len(mod.shape) > 2]

model_saver = TorchModelSaver(trainer_params, path1)
model_saver.register(policy)
model_saver.initialize_or_load(policy)
policy.set_step(2000)

mock_brain_name = "MockBrain"
model_saver.save_checkpoint(mock_brain_name, 2000)

# Try load from this path
trainer_params2 = TrainerSettings()
trainer_params2.network_settings = NetworkSettings(
hidden_units=10, vis_encode_type=EncoderType(vis_encode_type)
)
model_saver2 = TorchModelSaver(trainer_params2, path1, load=True)
policy2 = create_policy_mock(trainer_params2, use_visual=True)
conv_params2 = [mod for mod in policy2.actor.parameters() if len(mod.shape) > 2]
# asserts convolutions have different parameters before load
for conv1, conv2 in zip(conv_params, conv_params2):
assert not torch.equal(conv1, conv2)
# asserts layers have different dimensions before load
for mod1, mod2 in zip(policy.actor.parameters(), policy2.actor.parameters()):
if mod1.shape[0] == 12:
assert mod2.shape[0] == 10
model_saver2.register(policy2)
model_saver2.initialize_or_load(policy2)
# asserts convolutions have same parameters after load
for conv1, conv2 in zip(conv_params, conv_params2):
assert torch.equal(conv1, conv2)
# asserts layers still have different dimensions
for mod1, mod2 in zip(policy.actor.parameters(), policy2.actor.parameters()):
if mod1.shape[0] == 12:
assert mod2.shape[0] == 10


@pytest.mark.parametrize(
"optimizer",
[
64 changes: 62 additions & 2 deletions
@@ -3,8 +3,10 @@

import numpy as np

from mlagents_envs.logging_util import WARNING
from mlagents.trainers.ppo.optimizer_torch import TorchPPOOptimizer
from mlagents.trainers.sac.optimizer_torch import TorchSACOptimizer
from mlagents.trainers.poca.optimizer_torch import TorchPOCAOptimizer
from mlagents.trainers.model_saver.torch_model_saver import TorchModelSaver
from mlagents.trainers.settings import (
TrainerSettings,
@@ -14,12 +16,14 @@
RNDSettings,
PPOSettings,
SACSettings,
POCASettings,
)
from mlagents.trainers.tests.torch.test_policy import create_policy_mock
from mlagents.trainers.tests.torch.test_reward_providers.utils import (
create_agent_buffer,
)


DEMO_PATH = (
os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir, os.pardir)
+ "/test.demo"
@@ -28,8 +32,12 @@

@pytest.mark.parametrize(
"optimizer",
[(TorchPPOOptimizer, PPOSettings), (TorchSACOptimizer, SACSettings)],
ids=["ppo", "sac"],
[
(TorchPPOOptimizer, PPOSettings),
(TorchSACOptimizer, SACSettings),
(TorchPOCAOptimizer, POCASettings),
],
ids=["ppo", "sac", "poca"],
)
def test_reward_provider_save(tmp_path, optimizer):
OptimizerClass, HyperparametersClass = optimizer
@@ -87,3 +95,55 @@ def test_reward_provider_save(tmp_path, optimizer):
rp_1 = optimizer.reward_signals[reward_name]
rp_2 = optimizer2.reward_signals[reward_name]
assert np.array_equal(rp_1.evaluate(data), rp_2.evaluate(data))


@pytest.mark.parametrize(
"optimizer",
[
(TorchPPOOptimizer, PPOSettings),
(TorchSACOptimizer, SACSettings),
(TorchPOCAOptimizer, POCASettings),
],
ids=["ppo", "sac", "poca"],
)
def test_load_different_reward_provider(caplog, tmp_path, optimizer):
OptimizerClass, HyperparametersClass = optimizer

trainer_settings = TrainerSettings()
trainer_settings.hyperparameters = HyperparametersClass()
trainer_settings.reward_signals = {
RewardSignalType.CURIOSITY: CuriositySettings(),
RewardSignalType.RND: RNDSettings(),
}

policy = create_policy_mock(trainer_settings, use_discrete=False)
optimizer = OptimizerClass(policy, trainer_settings)

# save at path 1
path1 = os.path.join(tmp_path, "runid1")
model_saver = TorchModelSaver(trainer_settings, path1)
model_saver.register(policy)
model_saver.register(optimizer)
model_saver.initialize_or_load()
assert len(optimizer.critic.value_heads.stream_names) == 2
policy.set_step(2000)
model_saver.save_checkpoint("MockBrain", 2000)

trainer_settings2 = TrainerSettings()
trainer_settings2.hyperparameters = HyperparametersClass()
trainer_settings2.reward_signals = {
RewardSignalType.GAIL: GAILSettings(demo_path=DEMO_PATH)
}

# create a new optimizer and policy
policy2 = create_policy_mock(trainer_settings2, use_discrete=False)
optimizer2 = OptimizerClass(policy2, trainer_settings2)

# load weights
model_saver2 = TorchModelSaver(trainer_settings2, path1, load=True)
model_saver2.register(policy2)
model_saver2.register(optimizer2)
assert len(optimizer2.critic.value_heads.stream_names) == 1
model_saver2.initialize_or_load() # This is to load the optimizers
messages = [rec.message for rec in caplog.records if rec.levelno == WARNING]
assert len(messages) > 0
