Skip to content

Commit

Permalink
Fix set_env when using VecNormalize (#638)
Browse files Browse the repository at this point in the history
* Fix `set_env` when using `VecNormalize`

* Update version
  • Loading branch information
araffin authored Nov 2, 2021
1 parent 6daf82b commit 2bb4500
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 3 deletions.
9 changes: 7 additions & 2 deletions docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ Changelog
==========


Release 1.3.1a0 (WIP)
Release 1.3.1a1 (WIP)
---------------------------

Breaking Changes:
Expand All @@ -16,8 +16,10 @@ New Features:

Bug Fixes:
^^^^^^^^^^
- Fixed a bug where ``set_env()`` with ``VecNormalize`` would result in an error with off-policy algorithms (thanks @cleversonahum)
- FPS calculation is now performed based on number of steps performed during last ``learn`` call, even when ``reset_num_timesteps`` is set to ``False`` (@kachayev)


Deprecations:
^^^^^^^^^^^^^

Expand Down Expand Up @@ -830,4 +832,7 @@ And all the contributors:
@tirafesi @blurLake @koulakis @joeljosephjin @shwang @rk37 @andyshih12 @RaphaelWag @xicocaio
@diditforlulz273 @liorcohen5 @ManifoldFR @mloo3 @SwamyDev @wmmc88 @megan-klaiber @thisray
@tfederico @hn2 @LucasAlegre @AptX395 @zampanteymedio @JadenTravnik @decodyng @ardabbour @lorenz-h @mschweizer @lorepieri8 @vwxyzjn
@ShangqunYu @PierreExeter @JacopoPan @ltbd78 @tom-doerr @Atlis @liusida @09tangriro @amy12xx @juancroldan @benblack769 @bstee615 @c-rizz @skandermoalla @MihaiAnca13 @davidblom603 @ayeright @cyprienc @wkirgsn @AechPro @CUN-bjy @batu @IljaAvadiev @timokau @kachayev @eleurent @ac-93
@ShangqunYu @PierreExeter @JacopoPan @ltbd78 @tom-doerr @Atlis @liusida @09tangriro @amy12xx @juancroldan
@benblack769 @bstee615 @c-rizz @skandermoalla @MihaiAnca13 @davidblom603 @ayeright @cyprienc
@wkirgsn @AechPro @CUN-bjy @batu @IljaAvadiev @timokau @kachayev @cleversonahum
@eleurent @ac-93
4 changes: 4 additions & 0 deletions stable_baselines3/common/base_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,6 +499,10 @@ def set_env(self, env: GymEnv, force_reset: bool = True) -> None:
env = self._wrap_env(env, self.verbose)
# Check that the observation spaces match
check_for_correct_spaces(env, self.observation_space, self.action_space)
# Update VecNormalize object
# otherwise the wrong env may be used, see https://github.com/DLR-RM/stable-baselines3/issues/637
self._vec_normalize_env = unwrap_vec_normalize(env)

# Discard `_last_obs`, this will force the env to reset before training
# See issue https://github.com/DLR-RM/stable-baselines3/issues/597
if force_reset:
Expand Down
2 changes: 1 addition & 1 deletion stable_baselines3/version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
1.3.1a0
1.3.1a1
7 changes: 7 additions & 0 deletions tests/test_vec_normalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,13 @@ def test_offpolicy_normalization(model_class, online_sampling):
else:
model = model_class("MlpPolicy", env, verbose=1, learning_starts=100, policy_kwargs=dict(net_arch=[64]))

# Check that VecNormalize object is correctly updated
assert model.get_vec_normalize_env() is env
model.set_env(eval_env)
assert model.get_vec_normalize_env() is eval_env
model.learn(total_timesteps=10)
model.set_env(env)

model.learn(total_timesteps=150, eval_env=eval_env, eval_freq=75)
# Check getter
assert isinstance(model.get_vec_normalize_env(), VecNormalize)
Expand Down

0 comments on commit 2bb4500

Please sign in to comment.