
Commit 78cb833

Author: Ervin T
[bug-fix] Fix save/restore critic, add test (#5062)
* Fix save/restore critic, add test
* Rename module for PPO
* Use correct policy in test
1 parent dd6575d commit 78cb833

3 files changed: +65 -4 lines changed


ml-agents/mlagents/trainers/ppo/optimizer_torch.py

Lines changed: 4 additions & 1 deletion
@@ -232,7 +232,10 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
         return update_stats
 
     def get_modules(self):
-        modules = {"Optimizer": self.optimizer}
+        modules = {
+            "Optimizer:value_optimizer": self.optimizer,
+            "Optimizer:critic": self._critic,
+        }
         for reward_provider in self.reward_signals.values():
             modules.update(reward_provider.get_modules())
         return modules
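
The practical effect of this change: previously only the optimizer's state (self.optimizer) was returned from get_modules(), so the critic network itself was never written to or restored from a checkpoint. Below is a minimal sketch of why every parameter-owning module must appear in this dict, assuming a saver that simply round-trips each registered module's state_dict; the real TorchModelSaver does more bookkeeping, and save_modules, load_modules, and the file name here are hypothetical.

    import torch
    from torch import nn, optim

    def save_modules(modules: dict, path: str) -> None:
        # Persist one state_dict per registered module; anything absent from
        # `modules` simply never reaches the checkpoint file.
        torch.save({name: m.state_dict() for name, m in modules.items()}, path)

    def load_modules(modules: dict, path: str) -> None:
        # Restore each registered module from the checkpoint written above.
        saved = torch.load(path)
        for name, m in modules.items():
            m.load_state_dict(saved[name])

    # Hypothetical stand-ins for the PPO critic and its optimizer.
    critic = nn.Linear(8, 1)
    value_optimizer = optim.Adam(critic.parameters(), lr=3e-4)

    # With only {"Optimizer": value_optimizer} registered (the old behaviour),
    # the critic's weights are lost on restore; registering both fixes that.
    modules = {
        "Optimizer:value_optimizer": value_optimizer,
        "Optimizer:critic": critic,
    }
    save_modules(modules, "checkpoint.pt")
    load_modules(modules, "checkpoint.pt")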

ml-agents/mlagents/trainers/sac/optimizer_torch.py

Lines changed: 2 additions & 1 deletion
@@ -635,7 +635,8 @@ def update_reward_signals(
 
     def get_modules(self):
         modules = {
-            "Optimizer:value_network": self.q_network,
+            "Optimizer:q_network": self.q_network,
+            "Optimizer:value_network": self._critic,
             "Optimizer:target_network": self.target_network,
             "Optimizer:policy_optimizer": self.policy_optimizer,
             "Optimizer:value_optimizer": self.value_optimizer,

ml-agents/mlagents/trainers/tests/torch/saver/test_saver.py

Lines changed: 59 additions & 2 deletions
@@ -1,3 +1,4 @@
+from mlagents.trainers.optimizer.torch_optimizer import TorchOptimizer
 import pytest
 from unittest import mock
 import os
@@ -6,8 +7,9 @@
 from mlagents.torch_utils import torch, default_device
 from mlagents.trainers.policy.torch_policy import TorchPolicy
 from mlagents.trainers.ppo.optimizer_torch import TorchPPOOptimizer
+from mlagents.trainers.sac.optimizer_torch import TorchSACOptimizer
 from mlagents.trainers.model_saver.torch_model_saver import TorchModelSaver
-from mlagents.trainers.settings import TrainerSettings
+from mlagents.trainers.settings import TrainerSettings, PPOSettings, SACSettings
 from mlagents.trainers.tests import mock_brain as mb
 from mlagents.trainers.tests.torch.test_policy import create_policy_mock
 from mlagents.trainers.torch.utils import ModelUtils
@@ -29,7 +31,7 @@ def test_register(tmp_path):
     assert model_saver.policy is not None
 
 
-def test_load_save(tmp_path):
+def test_load_save_policy(tmp_path):
     path1 = os.path.join(tmp_path, "runid1")
     path2 = os.path.join(tmp_path, "runid2")
     trainer_params = TrainerSettings()
@@ -62,6 +64,42 @@ def test_load_save(tmp_path):
     assert policy3.get_current_step() == 0
 
 
+@pytest.mark.parametrize(
+    "optimizer",
+    [(TorchPPOOptimizer, PPOSettings), (TorchSACOptimizer, SACSettings)],
+    ids=["ppo", "sac"],
+)
+def test_load_save_optimizer(tmp_path, optimizer):
+    OptimizerClass, HyperparametersClass = optimizer
+
+    trainer_settings = TrainerSettings()
+    trainer_settings.hyperparameters = HyperparametersClass()
+    policy = create_policy_mock(trainer_settings, use_discrete=False)
+    optimizer = OptimizerClass(policy, trainer_settings)
+
+    # save at path 1
+    path1 = os.path.join(tmp_path, "runid1")
+    model_saver = TorchModelSaver(trainer_settings, path1)
+    model_saver.register(policy)
+    model_saver.register(optimizer)
+    model_saver.initialize_or_load()
+    policy.set_step(2000)
+    model_saver.save_checkpoint("MockBrain", 2000)
+
+    # create a new optimizer and policy
+    policy2 = create_policy_mock(trainer_settings, use_discrete=False)
+    optimizer2 = OptimizerClass(policy2, trainer_settings)
+
+    # load weights
+    model_saver2 = TorchModelSaver(trainer_settings, path1, load=True)
+    model_saver2.register(policy2)
+    model_saver2.register(optimizer2)
+    model_saver2.initialize_or_load()  # This is to load the optimizers
+
+    # Compare the two optimizers
+    _compare_two_optimizers(optimizer, optimizer2)
+
+
 # TorchPolicy.evalute() returns log_probs instead of all_log_probs like tf does.
 # resulting in indeterministic results for testing.
 # So here use sample_actions instead.
@@ -95,6 +133,25 @@ def _compare_two_policies(policy1: TorchPolicy, policy2: TorchPolicy) -> None:
     )
 
 
+def _compare_two_optimizers(opt1: TorchOptimizer, opt2: TorchOptimizer) -> None:
+    trajectory = mb.make_fake_trajectory(
+        length=10,
+        observation_specs=opt1.policy.behavior_spec.observation_specs,
+        action_spec=opt1.policy.behavior_spec.action_spec,
+        max_step_complete=True,
+    )
+    with torch.no_grad():
+        _, opt1_val_out, _ = opt1.get_trajectory_value_estimates(
+            trajectory.to_agentbuffer(), trajectory.next_obs, done=False
+        )
+        _, opt2_val_out, _ = opt2.get_trajectory_value_estimates(
+            trajectory.to_agentbuffer(), trajectory.next_obs, done=False
+        )
+
+    for opt1_val, opt2_val in zip(opt1_val_out.values(), opt2_val_out.values()):
+        np.testing.assert_array_equal(opt1_val, opt2_val)
+
+
 @pytest.mark.parametrize("discrete", [True, False], ids=["discrete", "continuous"])
 @pytest.mark.parametrize("visual", [True, False], ids=["visual", "vector"])
 @pytest.mark.parametrize("rnn", [True, False], ids=["rnn", "no_rnn"])
