diff --git a/tests/common/config_test.py b/tests/common/config_test.py
index 07e9a383e2..37b9805ba3 100644
--- a/tests/common/config_test.py
+++ b/tests/common/config_test.py
@@ -137,6 +137,23 @@ def test_max_token_len_per_gpu_set_correctly(self):
                 expected_max_token_len,
             )
 
+    def test_optimizer_config_propagation(self):
+        config = get_template_config()
+        config.algorithm.optimizer.lr = 1e-4
+        config.algorithm.optimizer.weight_decay = 0.05
+        config.check_and_update()
+        self.assertEqual(config.trainer.trainer_config.actor_rollout_ref.actor.optim.lr, 1e-4)
+        self.assertEqual(
+            config.trainer.trainer_config.actor_rollout_ref.actor.optim.weight_decay, 0.05
+        )
+        self.assertEqual(
+            config.trainer.trainer_config.actor_rollout_ref.actor.optim.lr_decay_style, "constant"
+        )  # default value
+        # critic optimizer should not be affected
+        self.assertEqual(config.trainer.trainer_config.critic.optim.lr, 1e-5)
+        self.assertEqual(config.trainer.trainer_config.critic.optim.weight_decay, 0.01)
+        self.assertEqual(config.trainer.trainer_config.critic.optim.lr_decay_style, "constant")
+
     def tearDown(self):
         if os.path.exists(CHECKPOINT_ROOT_DIR):
             shutil.rmtree(CHECKPOINT_ROOT_DIR)
diff --git a/trinity/common/config.py b/trinity/common/config.py
index ea9f18e8b9..dd328acd55 100644
--- a/trinity/common/config.py
+++ b/trinity/common/config.py
@@ -97,6 +97,7 @@ class OptimizerConfig:
     warmup_style: str = "constant"
     optimizer_type: str = "adam"
     betas: List[float] = field(default_factory=lambda: [0.9, 0.999])
+    weight_decay: float = 0.01
 
 
 @dataclass