diff --git a/LICENSE b/LICENSE index 6a7aa81e6..322a77c33 100644 --- a/LICENSE +++ b/LICENSE @@ -1,6 +1,6 @@ MIT License -Copyright (c) 2020 Tianshou contributors +Copyright (c) 2022 Tianshou contributors Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/docs/index.rst b/docs/index.rst index b131ff402..009e9eb96 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -13,7 +13,7 @@ Welcome to Tianshou! * :class:`~tianshou.policy.DQNPolicy` `Double DQN `_ * :class:`~tianshou.policy.DQNPolicy` `Dueling DQN `_ * :class:`~tianshou.policy.C51Policy` `Categorical DQN `_ -* :class:`~tianshou.policy.RainbowPolicy` `Rainbow DQN `_ +* :class:`~tianshou.policy.RainbowPolicy` `Rainbow DQN `_ * :class:`~tianshou.policy.QRDQNPolicy` `Quantile Regression DQN `_ * :class:`~tianshou.policy.IQNPolicy` `Implicit Quantile Network `_ * :class:`~tianshou.policy.FQFPolicy` `Fully-parameterized Quantile Function `_ diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 592f0916e..537986cf6 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -64,14 +64,15 @@ def __init__( super().__init__() if isinstance(env, gym.Env) and not hasattr(env, "__len__"): warnings.warn("Single environment detected, wrap to DummyVectorEnv.") - env = DummyVectorEnv([lambda: env]) - self.env = env - self.env_num = len(env) + self.env = DummyVectorEnv([lambda: env]) # type: ignore + else: + self.env = env # type: ignore + self.env_num = len(self.env) self.exploration_noise = exploration_noise self._assign_buffer(buffer) self.policy = policy self.preprocess_fn = preprocess_fn - self._action_space = env.action_space + self._action_space = self.env.action_space # avoid creating attribute outside __init__ self.reset(False) diff --git a/tianshou/env/pettingzoo_env.py b/tianshou/env/pettingzoo_env.py index de752516a..68e3f9648 100644 --- a/tianshou/env/pettingzoo_env.py +++ b/tianshou/env/pettingzoo_env.py @@ -6,7 +6,7 @@ from pettingzoo.utils.wrappers import BaseWrapper -class PettingZooEnv(AECEnv, gym.Env, ABC): +class PettingZooEnv(AECEnv, ABC): """The interface for petting zoo environments. Multi-agent environments must be wrapped as diff --git a/tianshou/env/venvs.py b/tianshou/env/venvs.py index c668109b6..044ecaaa4 100644 --- a/tianshou/env/venvs.py +++ b/tianshou/env/venvs.py @@ -12,7 +12,7 @@ from tianshou.utils import RunningMeanStd -class BaseVectorEnv(gym.Env): +class BaseVectorEnv(object): """Base class for vectorized environments wrapper. Usage: @@ -196,6 +196,7 @@ def _assert_id(self, id: Union[List[int], np.ndarray]) -> None: assert i in self.ready_id, \ f"Can only interact with ready environments {self.ready_id}." + # TODO: compatible issue with reset -> (obs, info) def reset( self, id: Optional[Union[int, List[int], np.ndarray]] = None ) -> np.ndarray: diff --git a/tianshou/env/worker/dummy.py b/tianshou/env/worker/dummy.py index 958f6e907..be873861c 100644 --- a/tianshou/env/worker/dummy.py +++ b/tianshou/env/worker/dummy.py @@ -31,9 +31,9 @@ def wait( # type: ignore def send(self, action: Optional[np.ndarray]) -> None: if action is None: - self.result = self.env.reset() + self.result = self.env.reset() # type: ignore else: - self.result = self.env.step(action) + self.result = self.env.step(action) # type: ignore def seed(self, seed: Optional[int] = None) -> List[int]: super().seed(seed) diff --git a/tianshou/env/worker/subproc.py b/tianshou/env/worker/subproc.py index 779b78e34..c2119ab50 100644 --- a/tianshou/env/worker/subproc.py +++ b/tianshou/env/worker/subproc.py @@ -53,7 +53,7 @@ def _setup_buf(space: gym.Space) -> Union[dict, tuple, ShArray]: assert isinstance(space.spaces, tuple) return tuple([_setup_buf(t) for t in space.spaces]) else: - return ShArray(space.dtype, space.shape) + return ShArray(space.dtype, space.shape) # type: ignore def _worker( diff --git a/tianshou/policy/modelfree/sac.py b/tianshou/policy/modelfree/sac.py index 03469dcc8..ac5b3e07a 100644 --- a/tianshou/policy/modelfree/sac.py +++ b/tianshou/policy/modelfree/sac.py @@ -122,9 +122,8 @@ def forward( # type: ignore # You can check out the original SAC paper (arXiv 1801.01290): Eq 21. # in appendix C to get some understanding of this equation. if self.action_scaling and self.action_space is not None: - action_scale = to_torch_as( - (self.action_space.high - self.action_space.low) / 2.0, act - ) + low, high = self.action_space.low, self.action_space.high # type: ignore + action_scale = to_torch_as((high - low) / 2.0, act) else: action_scale = 1.0 # type: ignore squashed_action = torch.tanh(act)