Skip to content

Commit

Permalink
support observation normalization in BaseVectorEnv (#308)
Browse files Browse the repository at this point in the history
add RunningMeanStd
  • Loading branch information
ChenDRAG authored Mar 11, 2021
1 parent 5c53f8c commit 243ab43
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 43 deletions.
13 changes: 12 additions & 1 deletion test/base/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import torch
import numpy as np

from tianshou.utils import MovAvg
from tianshou.utils.net.common import MLP, Net
from tianshou.utils import MovAvg, RunningMeanStd
from tianshou.exploration import GaussianNoise, OUNoise
from tianshou.utils.net.continuous import RecurrentActorProb, RecurrentCritic

Expand Down Expand Up @@ -30,6 +30,16 @@ def test_moving_average():
assert np.allclose(stat.std() ** 2, 2)


def test_rms():
    """Check RunningMeanStd defaults and its statistics after two updates."""
    stats = RunningMeanStd()

    # before any update the tracker reports mean 0 and variance 1
    assert np.allclose(stats.mean, 0)
    assert np.allclose(stats.var, 1)

    # feed one batch with a single sample, then one batch with two samples
    batches = (
        np.array([[[1, 2], [3, 5]]]),
        np.array([[[1, 2], [3, 4]], [[1, 2], [0, 0]]]),
    )
    for batch in batches:
        stats.update(batch)

    # element-wise statistics over the three samples seen so far
    expected_mean = np.array([[1, 2], [2, 3]])
    expected_var = np.array([[0, 0], [2, 14 / 3.]])
    assert np.allclose(stats.mean, expected_mean, atol=1e-3)
    assert np.allclose(stats.var, expected_var, atol=1e-3)


def test_net():
# test the networks that do not appear in the other scripts
bsz = 64
Expand Down Expand Up @@ -79,4 +89,5 @@ def test_net():
if __name__ == '__main__':
test_noise()
test_moving_average()
test_rms()
test_net()
88 changes: 47 additions & 41 deletions tianshou/env/venvs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import numpy as np
from typing import Any, List, Union, Optional, Callable

from tianshou.utils import RunningMeanStd
from tianshou.env.worker import EnvWorker, DummyEnvWorker, SubprocEnvWorker, \
RayEnvWorker

Expand Down Expand Up @@ -55,6 +56,13 @@ def seed(self, seed):
:param float timeout: use in asynchronous simulation same as above, in each
vectorized step it only deal with those environments spending time
within ``timeout`` seconds.
:param bool norm_obs: whether to track the mean/std of observations and normalize
observations on return. For now, observation normalization only supports
observations of type np.ndarray.
:param obs_rms: object that tracks the mean and std of observations. If not given
and ``norm_obs`` is True, a new one is initialized. For envs that are used to
evaluate an algorithm, obs_rms should usually be passed in. Defaults to None.
:param bool update_obs_rms: whether to update obs_rms. Defaults to True.
"""

def __init__(
Expand All @@ -63,6 +71,9 @@ def __init__(
worker_fn: Callable[[Callable[[], gym.Env]], EnvWorker],
wait_num: Optional[int] = None,
timeout: Optional[float] = None,
norm_obs: bool = False,
obs_rms: Optional[RunningMeanStd] = None,
update_obs_rms: bool = True,
) -> None:
self._env_fns = env_fns
# A VectorEnv contains a pool of EnvWorkers, which corresponds to
Expand Down Expand Up @@ -90,6 +101,12 @@ def __init__(
self.ready_id = list(range(self.env_num))
self.is_closed = False

# initialize observation running mean/std
self.norm_obs = norm_obs
self.update_obs_rms = update_obs_rms
self.obs_rms = RunningMeanStd() if obs_rms is None and norm_obs else obs_rms
self.__eps = np.finfo(np.float32).eps.item()

def _assert_is_not_closed(self) -> None:
assert not self.is_closed, \
f"Methods of {self.__class__.__name__} cannot be called after close."
Expand Down Expand Up @@ -149,7 +166,9 @@ def reset(
if self.is_async:
self._assert_id(id)
obs = np.stack([self.workers[i].reset() for i in id])
return obs
if self.obs_rms and self.update_obs_rms:
self.obs_rms.update(obs)
return self.normalize_obs(obs)

def step(
self,
Expand Down Expand Up @@ -219,7 +238,10 @@ def step(
info["env_id"] = env_id
result.append((obs, rew, done, info))
self.ready_id.append(env_id)
return list(map(np.stack, zip(*result)))
obs_stack, rew_stack, done_stack, info_stack = map(np.stack, zip(*result))
if self.obs_rms and self.update_obs_rms:
self.obs_rms.update(obs_stack)
return [self.normalize_obs(obs_stack), rew_stack, done_stack, info_stack]

def seed(
self, seed: Optional[Union[int, List[int]]] = None
Expand Down Expand Up @@ -255,15 +277,23 @@ def render(self, **kwargs: Any) -> List[Any]:
def close(self) -> None:
    """Close every worker environment.

    ``close`` is meant to run exactly once. If it is never called explicitly,
    it is invoked when the object is garbage collected, so closing all
    workers is assured either way.
    """
    # calling close on an already closed vector env is an error
    self._assert_is_not_closed()
    for worker in self.workers:
        worker.close()
    self.is_closed = True

def normalize_obs(self, obs: np.ndarray) -> np.ndarray:
    """Standardize ``obs`` with the running statistics held in ``obs_rms``.

    The observation is returned untouched unless ``obs_rms`` is set and
    ``norm_obs`` is enabled. Otherwise the mean is subtracted, the result is
    divided by the standard deviation (with a small epsilon added to the
    variance for numerical stability) and clipped to ``[-10, 10]``.

    :param obs: observation(s) to normalize.
    :return: the normalized and clipped observation, or ``obs`` itself when
        normalization is disabled.
    """
    if not (self.obs_rms and self.norm_obs):
        return obs
    # clipping bound: this magic number is from openai baselines,
    # see baselines/common/vec_env/vec_normalize.py#L10
    bound = 10.0
    centered = obs - self.obs_rms.mean
    scale = np.sqrt(self.obs_rms.var + self.__eps)
    return np.clip(centered / scale, -bound, bound)

def __del__(self) -> None:
"""Redirect to self.close()."""
if not self.is_closed:
Expand All @@ -275,38 +305,26 @@ class DummyVectorEnv(BaseVectorEnv):
.. seealso::
Please refer to :class:`~tianshou.env.BaseVectorEnv` for more detailed
explanation.
Please refer to :class:`~tianshou.env.BaseVectorEnv` for other APIs' usage.
"""

def __init__(
self,
env_fns: List[Callable[[], gym.Env]],
wait_num: Optional[int] = None,
timeout: Optional[float] = None,
) -> None:
super().__init__(env_fns, DummyEnvWorker, wait_num=wait_num, timeout=timeout)
def __init__(self, env_fns: List[Callable[[], gym.Env]], **kwargs: Any) -> None:
super().__init__(env_fns, DummyEnvWorker, **kwargs)


class SubprocVectorEnv(BaseVectorEnv):
"""Vectorized environment wrapper based on subprocess.
.. seealso::
Please refer to :class:`~tianshou.env.BaseVectorEnv` for more detailed
explanation.
Please refer to :class:`~tianshou.env.BaseVectorEnv` for other APIs' usage.
"""

def __init__(
self,
env_fns: List[Callable[[], gym.Env]],
wait_num: Optional[int] = None,
timeout: Optional[float] = None,
) -> None:
def __init__(self, env_fns: List[Callable[[], gym.Env]], **kwargs: Any) -> None:
def worker_fn(fn: Callable[[], gym.Env]) -> SubprocEnvWorker:
return SubprocEnvWorker(fn, share_memory=False)

super().__init__(env_fns, worker_fn, wait_num=wait_num, timeout=timeout)
super().__init__(env_fns, worker_fn, **kwargs)


class ShmemVectorEnv(BaseVectorEnv):
Expand All @@ -316,20 +334,14 @@ class ShmemVectorEnv(BaseVectorEnv):
.. seealso::
Please refer to :class:`~tianshou.env.SubprocVectorEnv` for more
detailed explanation.
Please refer to :class:`~tianshou.env.BaseVectorEnv` for other APIs' usage.
"""

def __init__(
self,
env_fns: List[Callable[[], gym.Env]],
wait_num: Optional[int] = None,
timeout: Optional[float] = None,
) -> None:
def __init__(self, env_fns: List[Callable[[], gym.Env]], **kwargs: Any) -> None:
def worker_fn(fn: Callable[[], gym.Env]) -> SubprocEnvWorker:
return SubprocEnvWorker(fn, share_memory=True)

super().__init__(env_fns, worker_fn, wait_num=wait_num, timeout=timeout)
super().__init__(env_fns, worker_fn, **kwargs)


class RayVectorEnv(BaseVectorEnv):
Expand All @@ -339,16 +351,10 @@ class RayVectorEnv(BaseVectorEnv):
.. seealso::
Please refer to :class:`~tianshou.env.BaseVectorEnv` for more detailed
explanation.
Please refer to :class:`~tianshou.env.BaseVectorEnv` for other APIs' usage.
"""

def __init__(
self,
env_fns: List[Callable[[], gym.Env]],
wait_num: Optional[int] = None,
timeout: Optional[float] = None,
) -> None:
def __init__(self, env_fns: List[Callable[[], gym.Env]], **kwargs: Any) -> None:
try:
import ray
except ImportError as e:
Expand All @@ -357,4 +363,4 @@ def __init__(
) from e
if not ray.is_initialized():
ray.init()
super().__init__(env_fns, RayEnvWorker, wait_num=wait_num, timeout=timeout)
super().__init__(env_fns, RayEnvWorker, **kwargs)
3 changes: 2 additions & 1 deletion tianshou/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from tianshou.utils.config import tqdm_config
from tianshou.utils.moving_average import MovAvg
from tianshou.utils.statistics import MovAvg, RunningMeanStd
from tianshou.utils.log_tools import BasicLogger, LazyLogger, BaseLogger

__all__ = [
"MovAvg",
"RunningMeanStd",
"tqdm_config",
"BaseLogger",
"BasicLogger",
Expand Down
28 changes: 28 additions & 0 deletions tianshou/utils/moving_average.py → tianshou/utils/statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,31 @@ def std(self) -> np.number:
if len(self.cache) == 0:
return 0
return np.std(self.cache)


class RunningMeanStd(object):
    """Calculate the running mean and variance of a data stream.

    Batches are merged into the accumulated statistics with the parallel
    algorithm described at
    https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm

    ``mean`` starts at 0.0 and ``var`` at 1.0 with ``count`` equal to 0, so
    these initial values carry no weight and are fully replaced by the first
    non-empty batch.
    """

    def __init__(self) -> None:
        self.mean, self.var = 0.0, 1.0
        self.count = 0

    def update(self, x: np.ndarray) -> None:
        """Add a batch of items to the statistics, modifying mean/var/count.

        :param x: batch of items that all have the same shape; the first axis
            indexes the items and statistics are computed along it. An empty
            batch leaves the statistics unchanged.
        """
        batch_count = len(x)
        if batch_count == 0:
            # The mean/variance of an empty batch are NaN; merging them would
            # permanently turn the running statistics into NaN.
            return
        batch_mean, batch_var = np.mean(x, axis=0), np.var(x, axis=0)

        delta = batch_mean - self.mean
        total_count = self.count + batch_count

        new_mean = self.mean + delta * batch_count / total_count
        # combine the sums of squared deviations of both partitions
        m_a = self.var * self.count
        m_b = batch_var * batch_count
        m_2 = m_a + m_b + delta ** 2 * self.count * batch_count / total_count
        new_var = m_2 / total_count

        self.mean, self.var = new_mean, new_var
        self.count = total_count

0 comments on commit 243ab43

Please sign in to comment.