+ from mlagents.trainers.optimizer.torch_optimizer import TorchOptimizer
  import pytest
  from unittest import mock
  import os
  import numpy as np

  from mlagents.torch_utils import torch, default_device
  from mlagents.trainers.policy.torch_policy import TorchPolicy
  from mlagents.trainers.ppo.optimizer_torch import TorchPPOOptimizer
+ from mlagents.trainers.sac.optimizer_torch import TorchSACOptimizer
  from mlagents.trainers.model_saver.torch_model_saver import TorchModelSaver
- from mlagents.trainers.settings import TrainerSettings
+ from mlagents.trainers.settings import TrainerSettings, PPOSettings, SACSettings
  from mlagents.trainers.tests import mock_brain as mb
  from mlagents.trainers.tests.torch.test_policy import create_policy_mock
  from mlagents.trainers.torch.utils import ModelUtils
@@ -29,7 +31,7 @@ def test_register(tmp_path):
      assert model_saver.policy is not None


- def test_load_save(tmp_path):
+ def test_load_save_policy(tmp_path):
      path1 = os.path.join(tmp_path, "runid1")
      path2 = os.path.join(tmp_path, "runid2")
      trainer_params = TrainerSettings()
@@ -62,6 +64,42 @@ def test_load_save(tmp_path):
      assert policy3.get_current_step() == 0


+ @pytest.mark.parametrize(
+     "optimizer",
+     [(TorchPPOOptimizer, PPOSettings), (TorchSACOptimizer, SACSettings)],
+     ids=["ppo", "sac"],
+ )
+ def test_load_save_optimizer(tmp_path, optimizer):
+     OptimizerClass, HyperparametersClass = optimizer
+
+     trainer_settings = TrainerSettings()
+     trainer_settings.hyperparameters = HyperparametersClass()
+     policy = create_policy_mock(trainer_settings, use_discrete=False)
+     optimizer = OptimizerClass(policy, trainer_settings)
+
+     # save at path 1
+     path1 = os.path.join(tmp_path, "runid1")
+     model_saver = TorchModelSaver(trainer_settings, path1)
+     model_saver.register(policy)
+     model_saver.register(optimizer)
+     model_saver.initialize_or_load()
+     policy.set_step(2000)
+     model_saver.save_checkpoint("MockBrain", 2000)
+
+     # create a new optimizer and policy
+     policy2 = create_policy_mock(trainer_settings, use_discrete=False)
+     optimizer2 = OptimizerClass(policy2, trainer_settings)
+
+     # load weights
+     model_saver2 = TorchModelSaver(trainer_settings, path1, load=True)
+     model_saver2.register(policy2)
+     model_saver2.register(optimizer2)
+     model_saver2.initialize_or_load()  # This is to load the optimizers
+
+     # Compare the two optimizers
+     _compare_two_optimizers(optimizer, optimizer2)
+
+
  # TorchPolicy.evaluate() returns log_probs instead of all_log_probs like tf does,
  # resulting in nondeterministic results for testing.
  # So here we use sample_actions instead.
@@ -95,6 +133,25 @@ def _compare_two_policies(policy1: TorchPolicy, policy2: TorchPolicy) -> None:
      )


+ def _compare_two_optimizers(opt1: TorchOptimizer, opt2: TorchOptimizer) -> None:
+     trajectory = mb.make_fake_trajectory(
+         length=10,
+         observation_specs=opt1.policy.behavior_spec.observation_specs,
+         action_spec=opt1.policy.behavior_spec.action_spec,
+         max_step_complete=True,
+     )
+     with torch.no_grad():
+         _, opt1_val_out, _ = opt1.get_trajectory_value_estimates(
+             trajectory.to_agentbuffer(), trajectory.next_obs, done=False
+         )
+         _, opt2_val_out, _ = opt2.get_trajectory_value_estimates(
+             trajectory.to_agentbuffer(), trajectory.next_obs, done=False
+         )
+
+     for opt1_val, opt2_val in zip(opt1_val_out.values(), opt2_val_out.values()):
+         np.testing.assert_array_equal(opt1_val, opt2_val)
+
+
  @pytest.mark.parametrize("discrete", [True, False], ids=["discrete", "continuous"])
  @pytest.mark.parametrize("visual", [True, False], ids=["visual", "vector"])
  @pytest.mark.parametrize("rnn", [True, False], ids=["rnn", "no_rnn"])
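The two new round-trip cases can be run on their own with pytest's `-k` selector, which matches both the "ppo" and "sac" parametrizations by test name. A minimal sketch, assuming the edited file is ml-agents/mlagents/trainers/tests/torch/test_model_saver.py (the diff itself does not name the file):

import pytest

# Run only test_load_save_optimizer; -k selects tests whose names contain
# the given substring, so both parametrized ids ("ppo" and "sac") are picked up.
# The file path below is an assumption, not taken from the diff.
pytest.main(
    [
        "-v",
        "-k",
        "test_load_save_optimizer",
        "ml-agents/mlagents/trainers/tests/torch/test_model_saver.py",
    ]
)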