From d7d8c6e1d14fbc1a6d0405f1fb3b9327df1baa45 Mon Sep 17 00:00:00 2001 From: Alejandro Campoy Nieves Date: Wed, 3 Jul 2024 12:36:34 +0200 Subject: [PATCH] (v3.3.8) - Observation normalization bug Fix (again), negative values in obs_rms.var (#422) * Evl Callback: Using argument train_env instead of inherited training environment * Evl Callback: Fixed mean and var normalization calibration set (now it is applied correctly) * Normalization Wrapper: Deleted RecordConstructorArgs * Normalization wrapper: Deleted RecordConstructorArgs inherit * Normalize Wrapper: Fixed var property bug (returning mean again instead of var) * Updated Sinergym version from 3.3.7 to 3.3.8 --- sinergym/utils/callbacks.py | 10 ++++++---- sinergym/utils/wrappers.py | 8 ++------ sinergym/version.txt | 2 +- 3 files changed, 9 insertions(+), 11 deletions(-) diff --git a/sinergym/utils/callbacks.py b/sinergym/utils/callbacks.py index bf0a91cda3..8b2182abf1 100644 --- a/sinergym/utils/callbacks.py +++ b/sinergym/utils/callbacks.py @@ -359,7 +359,7 @@ def _on_step(self) -> bool: self._is_success_buffer = [] # We close training env before to start the evaluation - self.training_env.close() + self.train_env.close() self._sync_envs() @@ -375,7 +375,7 @@ def _on_step(self) -> bool: # We close evaluation env and starts training env again self.eval_env.close() - self.training_env.reset() + self.train_env.reset() if self.log_path is not None: for key, value in episodes_data.items(): @@ -502,5 +502,7 @@ def _sync_envs(self): self.eval_env, NormalizeObservation): self.eval_env.get_wrapper_attr('deactivate_update')() - self.eval_env.obs_rms = deepcopy( - self.train_env.get_wrapper_attr('obs_rms')) + self.eval_env.get_wrapper_attr('set_mean')( + self.train_env.get_wrapper_attr('mean')) + self.eval_env.get_wrapper_attr('set_var')( + self.train_env.get_wrapper_attr('var')) diff --git a/sinergym/utils/wrappers.py b/sinergym/utils/wrappers.py index d8d2fef19f..b78af69039 100644 --- a/sinergym/utils/wrappers.py 
+++ b/sinergym/utils/wrappers.py @@ -56,7 +56,7 @@ def step(self, action: Union[int, np.ndarray]) -> Tuple[ return obs, reward_vector, terminated, truncated, info -class NormalizeObservation(gym.Wrapper, gym.utils.RecordConstructorArgs): +class NormalizeObservation(gym.Wrapper): logger = Logger().getLogger(name='WRAPPER NormalizeObservation', level=LOG_WRAPPERS_LEVEL) @@ -82,10 +82,6 @@ def __init__(self, mean = self._check_and_update_metric(mean, 'mean') var = self._check_and_update_metric(var, 'var') - # Save normalization configuration for whole python process - gym.utils.RecordConstructorArgs.__init__( - self, epsilon=epsilon, mean=mean, var=var) - self.num_envs = 1 self.is_vector_env = False self.automatic_update = automatic_update @@ -196,7 +192,7 @@ def mean(self) -> Optional[np.float64]: def var(self) -> Optional[np.float64]: """Returns the variance value of the observations.""" if hasattr(self, 'obs_rms'): - return self.obs_rms.mean + return self.obs_rms.var else: return None diff --git a/sinergym/version.txt b/sinergym/version.txt index 010d183f8b..7cb75caa9d 100644 --- a/sinergym/version.txt +++ b/sinergym/version.txt @@ -1 +1 @@ -3.3.7 \ No newline at end of file +3.3.8 \ No newline at end of file