From 4d92952a7be27a2ebf9bb59fccdd29bfa2c954b3 Mon Sep 17 00:00:00 2001
From: ChenDRAG <40993476+ChenDRAG@users.noreply.github.com>
Date: Sun, 21 Mar 2021 16:45:50 +0800
Subject: [PATCH] Remap action to fit gym's action space (#313)

Co-authored-by: Trinkle23897
---
 examples/box2d/bipedal_hardcore_sac.py    |  3 +-
 examples/box2d/mcc_sac.py                 |  4 +-
 examples/mujoco/mujoco_ddpg.py            |  3 +-
 examples/mujoco/mujoco_sac.py             |  3 +-
 examples/mujoco/mujoco_td3.py             |  4 +-
 test/continuous/test_ddpg.py              |  3 +-
 test/continuous/test_ppo.py               |  5 +--
 test/continuous/test_sac_with_il.py       |  3 +-
 test/continuous/test_td3.py               |  4 +-
 test/discrete/test_a2c_with_il.py         |  3 +-
 test/discrete/test_pg.py                  |  3 +-
 test/discrete/test_ppo.py                 |  4 +-
 tianshou/data/collector.py                |  8 +++-
 tianshou/policy/base.py                   | 42 +++++++++++++++++-
 tianshou/policy/modelfree/a2c.py          |  7 +++
 tianshou/policy/modelfree/ddpg.py         | 21 ++++-----
 tianshou/policy/modelfree/discrete_sac.py |  8 ++--
 tianshou/policy/modelfree/pg.py           | 12 +++++-
 tianshou/policy/modelfree/ppo.py          | 15 ++++---
 tianshou/policy/modelfree/sac.py          | 52 +++++++++++++----------
 tianshou/policy/modelfree/td3.py          | 15 ++++---
 21 files changed, 145 insertions(+), 77 deletions(-)

diff --git a/examples/box2d/bipedal_hardcore_sac.py b/examples/box2d/bipedal_hardcore_sac.py
index d5a8f0577..5678277db 100644
--- a/examples/box2d/bipedal_hardcore_sac.py
+++ b/examples/box2d/bipedal_hardcore_sac.py
@@ -117,9 +117,8 @@ def test_sac_bipedal(args=get_args()):
 
     policy = SACPolicy(
         actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
-        action_range=[env.action_space.low[0], env.action_space.high[0]],
         tau=args.tau, gamma=args.gamma, alpha=args.alpha,
-        estimation_step=args.n_step)
+        estimation_step=args.n_step, action_space=env.action_space)
     # load a previous policy
     if args.resume_path:
         policy.load_state_dict(torch.load(args.resume_path))
diff --git a/examples/box2d/mcc_sac.py b/examples/box2d/mcc_sac.py
index f22c7846e..8fbc4257d 100644
--- a/examples/box2d/mcc_sac.py
+++ b/examples/box2d/mcc_sac.py
@@ -90,10 +90,10 @@ def test_sac(args=get_args()):
 
     policy = SACPolicy(
         actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
-        action_range=[env.action_space.low[0], env.action_space.high[0]],
         tau=args.tau, gamma=args.gamma, alpha=args.alpha,
         reward_normalization=args.rew_norm,
-        exploration_noise=OUNoise(0.0, args.noise_std))
+        exploration_noise=OUNoise(0.0, args.noise_std),
+        action_space=env.action_space)
     # collector
     train_collector = Collector(
         policy, train_envs,
diff --git a/examples/mujoco/mujoco_ddpg.py b/examples/mujoco/mujoco_ddpg.py
index 491d42375..9f5af3e4b 100755
--- a/examples/mujoco/mujoco_ddpg.py
+++ b/examples/mujoco/mujoco_ddpg.py
@@ -84,10 +84,9 @@ def test_ddpg(args=get_args()):
     critic_optim = torch.optim.Adam(critic.parameters(), lr=args.critic_lr)
     policy = DDPGPolicy(
         actor, actor_optim, critic, critic_optim,
-        action_range=[env.action_space.low[0], env.action_space.high[0]],
         tau=args.tau, gamma=args.gamma,
         exploration_noise=GaussianNoise(sigma=args.exploration_noise),
-        estimation_step=args.n_step)
+        estimation_step=args.n_step, action_space=env.action_space)
     # load a previous policy
     if args.resume_path:
         policy.load_state_dict(torch.load(
diff --git a/examples/mujoco/mujoco_sac.py b/examples/mujoco/mujoco_sac.py
index 73ef4eb49..cf57318a8 100755
--- a/examples/mujoco/mujoco_sac.py
+++ b/examples/mujoco/mujoco_sac.py
@@ -97,9 +97,8 @@ def test_sac(args=get_args()):
 
     policy = SACPolicy(
         actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
-        action_range=[env.action_space.low[0], env.action_space.high[0]],
         tau=args.tau, gamma=args.gamma, alpha=args.alpha,
-        estimation_step=args.n_step)
+        estimation_step=args.n_step, action_space=env.action_space)
     # load a previous policy
     if args.resume_path:
         policy.load_state_dict(torch.load(
diff --git a/examples/mujoco/mujoco_td3.py b/examples/mujoco/mujoco_td3.py
index 9066cbee1..fd7c4eae6 100755
--- a/examples/mujoco/mujoco_td3.py
+++ b/examples/mujoco/mujoco_td3.py
@@ -95,11 +95,11 @@ def test_td3(args=get_args()):
 
     policy = TD3Policy(
         actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
-        action_range=[env.action_space.low[0], env.action_space.high[0]],
         tau=args.tau, gamma=args.gamma,
         exploration_noise=GaussianNoise(sigma=args.exploration_noise),
         policy_noise=args.policy_noise, update_actor_freq=args.update_actor_freq,
-        noise_clip=args.noise_clip, estimation_step=args.n_step)
+        noise_clip=args.noise_clip, estimation_step=args.n_step,
+        action_space=env.action_space)
 
     # load a previous policy
     if args.resume_path:
diff --git a/test/continuous/test_ddpg.py b/test/continuous/test_ddpg.py
index 093aed196..5030abfe5 100644
--- a/test/continuous/test_ddpg.py
+++ b/test/continuous/test_ddpg.py
@@ -77,11 +77,10 @@ def test_ddpg(args=get_args()):
     critic_optim = torch.optim.Adam(critic.parameters(), lr=args.critic_lr)
     policy = DDPGPolicy(
         actor, actor_optim, critic, critic_optim,
-        action_range=[env.action_space.low[0], env.action_space.high[0]],
         tau=args.tau, gamma=args.gamma,
         exploration_noise=GaussianNoise(sigma=args.exploration_noise),
         reward_normalization=args.rew_norm,
-        estimation_step=args.n_step)
+        estimation_step=args.n_step, action_space=env.action_space)
     # collector
     train_collector = Collector(
         policy, train_envs,
diff --git a/test/continuous/test_ppo.py b/test/continuous/test_ppo.py
index 762c58838..895d3c1f5 100644
--- a/test/continuous/test_ppo.py
+++ b/test/continuous/test_ppo.py
@@ -100,9 +100,8 @@ def dist(*logits):
         # dual_clip=args.dual_clip,
         # dual clip cause monotonically increasing log_std :)
         value_clip=args.value_clip,
-        # action_range=[env.action_space.low[0], env.action_space.high[0]],)
-        # if clip the action, ppo would not converge :)
-        gae_lambda=args.gae_lambda)
+        gae_lambda=args.gae_lambda,
+        action_space=env.action_space)
     # collector
     train_collector = Collector(
         policy, train_envs,
diff --git a/test/continuous/test_sac_with_il.py b/test/continuous/test_sac_with_il.py
index ac533fcf4..8fb535ac7 100644
--- a/test/continuous/test_sac_with_il.py
+++ b/test/continuous/test_sac_with_il.py
@@ -87,10 +87,9 @@ def test_sac_with_il(args=get_args()):
     critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
     policy = SACPolicy(
         actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
-        action_range=[env.action_space.low[0], env.action_space.high[0]],
         tau=args.tau, gamma=args.gamma, alpha=args.alpha,
         reward_normalization=args.rew_norm,
-        estimation_step=args.n_step)
+        estimation_step=args.n_step, action_space=env.action_space)
     # collector
     train_collector = Collector(
         policy, train_envs,
diff --git a/test/continuous/test_td3.py b/test/continuous/test_td3.py
index 86331e993..2e0674372 100644
--- a/test/continuous/test_td3.py
+++ b/test/continuous/test_td3.py
@@ -86,14 +86,14 @@ def test_td3(args=get_args()):
     critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
     policy = TD3Policy(
         actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
-        action_range=[env.action_space.low[0], env.action_space.high[0]],
         tau=args.tau, gamma=args.gamma,
         exploration_noise=GaussianNoise(sigma=args.exploration_noise),
         policy_noise=args.policy_noise,
         update_actor_freq=args.update_actor_freq,
         noise_clip=args.noise_clip,
         reward_normalization=args.rew_norm,
-        estimation_step=args.n_step)
+        estimation_step=args.n_step,
+        action_space=env.action_space)
     # collector
     train_collector = Collector(
         policy, train_envs,
diff --git a/test/discrete/test_a2c_with_il.py b/test/discrete/test_a2c_with_il.py
index 882cb440a..323d14848 100644
--- a/test/discrete/test_a2c_with_il.py
+++ b/test/discrete/test_a2c_with_il.py
@@ -80,7 +80,8 @@ def test_a2c_with_il(args=get_args()):
     policy = A2CPolicy(
         actor, critic, optim, dist, args.gamma, gae_lambda=args.gae_lambda,
         vf_coef=args.vf_coef, ent_coef=args.ent_coef,
-        max_grad_norm=args.max_grad_norm, reward_normalization=args.rew_norm)
+        max_grad_norm=args.max_grad_norm, reward_normalization=args.rew_norm,
+        action_space=env.action_space)
     # collector
     train_collector = Collector(
         policy, train_envs,
diff --git a/test/discrete/test_pg.py b/test/discrete/test_pg.py
index 6ebeb2686..83e3c1f6b 100644
--- a/test/discrete/test_pg.py
+++ b/test/discrete/test_pg.py
@@ -63,7 +63,8 @@ def test_pg(args=get_args()):
    optim = torch.optim.Adam(net.parameters(), lr=args.lr)
    dist = torch.distributions.Categorical
    policy = PGPolicy(net, optim, dist, args.gamma,
-                      reward_normalization=args.rew_norm)
+                      reward_normalization=args.rew_norm,
+                      action_space=env.action_space)
    # collector
    train_collector = Collector(
        policy, train_envs,
diff --git a/test/discrete/test_ppo.py b/test/discrete/test_ppo.py
index 5821e7be8..11428dc0d 100644
--- a/test/discrete/test_ppo.py
+++ b/test/discrete/test_ppo.py
@@ -85,11 +85,11 @@ def test_ppo(args=get_args()):
         eps_clip=args.eps_clip,
         vf_coef=args.vf_coef,
         ent_coef=args.ent_coef,
-        action_range=None,
         gae_lambda=args.gae_lambda,
         reward_normalization=args.rew_norm,
         dual_clip=args.dual_clip,
-        value_clip=args.value_clip)
+        value_clip=args.value_clip,
+        action_space=env.action_space)
     # collector
     train_collector = Collector(
         policy, train_envs,
diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py
index 3a1b05d26..819342863 100644
--- a/tianshou/data/collector.py
+++ b/tianshou/data/collector.py
@@ -219,8 +219,10 @@ def collect(
                 act = self.policy.exploration_noise(act, self.data)
             self.data.update(policy=policy, act=act)
 
+            # get bounded and remapped actions first (not saved into buffer)
+            action_remap = self.policy.map_action(self.data.act)
             # step in env
-            obs_next, rew, done, info = self.env.step(self.data.act, id=ready_env_ids)
+            obs_next, rew, done, info = self.env.step(action_remap, id=ready_env_ids)
 
             self.data.update(obs_next=obs_next, rew=rew, done=done, info=info)
             if self.preprocess_fn:
@@ -419,8 +421,10 @@ def collect(
                 _alloc_by_keys_diff(whole_data, self.data, self.env_num, False)
                 whole_data[ready_env_ids] = self.data  # lots of overhead
 
+            # get bounded and remapped actions first (not saved into buffer)
+            action_remap = self.policy.map_action(self.data.act)
             # step in env
-            obs_next, rew, done, info = self.env.step(self.data.act, id=ready_env_ids)
+            obs_next, rew, done, info = self.env.step(action_remap, id=ready_env_ids)
 
             # change self.data here because ready_env_ids has changed
             ready_env_ids = np.array([i["env_id"] for i in info])
diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py
index 730ee28b0..1d420173b 100644
--- a/tianshou/policy/base.py
+++ b/tianshou/policy/base.py
@@ -53,14 +53,20 @@ class BasePolicy(ABC, nn.Module):
 
     def __init__(
         self,
-        observation_space: gym.Space = None,
-        action_space: gym.Space = None
+        observation_space: Optional[gym.Space] = None,
+        action_space: Optional[gym.Space] = None,
+        action_scaling: bool = False,
+        action_bound_method: str = "",
     ) -> None:
         super().__init__()
         self.observation_space = observation_space
         self.action_space = action_space
         self.agent_id = 0
         self.updating = False
+        self.action_scaling = action_scaling
+        # can be one of ("clip", "tanh", ""), empty string means no bounding
+        assert action_bound_method in ("", "clip", "tanh")
+        self.action_bound_method = action_bound_method
         self._compile()
 
     def set_agent_id(self, agent_id: int) -> None:
@@ -114,6 +120,38 @@ def forward(
         """
         pass
 
+    def map_action(self, act: Union[Batch, np.ndarray]) -> Union[Batch, np.ndarray]:
+        """Map raw network output to the action range of gym's env.action_space.
+
+        This function is called in :meth:`~tianshou.data.Collector.collect` and only
+        affects the action sent to the env. The remapped action is not stored in the
+        buffer and can thus be viewed as part of the env (a black-box transformation).
+
+        Action mapping includes two standard procedures: bounding and scaling. The
+        bounding procedure assumes the raw action ranges over (-inf, inf) and maps it
+        into [-1, 1]; the scaling procedure assumes a range of (-1, 1) and maps it to
+        [action_space.low, action_space.high]. Bounding is applied first.
+
+        :param act: a data batch or numpy.ndarray which is the action taken by
+            policy.forward.
+
+        :return: the action in the same form as the input "act", but remapped to the
+            target action space.
+        """
+        if isinstance(self.action_space, gym.spaces.Box) and \
+                isinstance(act, np.ndarray):
+            # currently this action mapping only supports np.ndarray action
+            if self.action_bound_method == "clip":
+                act = np.clip(act, -1.0, 1.0)
+            elif self.action_bound_method == "tanh":
+                act = np.tanh(act)
+            if self.action_scaling:
+                assert np.all(act >= -1.0) and np.all(act <= 1.0), \
+                    "action scaling only accepts raw action range = [-1, 1]"
+                low, high = self.action_space.low, self.action_space.high
+                act = low + (high - low) * (act + 1.0) / 2.0
+        return act
+
     def process_fn(
         self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray
     ) -> Batch:
diff --git a/tianshou/policy/modelfree/a2c.py b/tianshou/policy/modelfree/a2c.py
index f79682789..79fb308de 100644
--- a/tianshou/policy/modelfree/a2c.py
+++ b/tianshou/policy/modelfree/a2c.py
@@ -31,6 +31,13 @@ class A2CPolicy(PGPolicy):
         depends on the size of available memory and the memory cost of the
         model; should be as large as possible within the memory constraint.
         Default to 256.
+    :param bool action_scaling: whether to map actions from range [-1, 1] to range
+        [action_space.low, action_space.high]. Default to True.
+    :param str action_bound_method: method to bound action to range [-1, 1], can be
+        either "clip" (simply clip the action), "tanh" (apply tanh squashing), or an
+        empty string for no bounding. Default to "clip".
+    :param Optional[gym.Space] action_space: env's action space; required if you
+        want to use "action_scaling" or "action_bound_method". Default to None.
 
     .. seealso::
 
diff --git a/tianshou/policy/modelfree/ddpg.py b/tianshou/policy/modelfree/ddpg.py
index 87a06c53e..7d582fbae 100644
--- a/tianshou/policy/modelfree/ddpg.py
+++ b/tianshou/policy/modelfree/ddpg.py
@@ -16,8 +16,6 @@ class DDPGPolicy(BasePolicy):
     :param torch.optim.Optimizer actor_optim: the optimizer for actor network.
     :param torch.nn.Module critic: the critic network. (s, a -> Q(s, a))
     :param torch.optim.Optimizer critic_optim: the optimizer for critic network.
-    :param action_range: the action range (minimum, maximum).
-    :type action_range: Tuple[float, float]
     :param float tau: param for soft update of the target network. Default to 0.005.
     :param float gamma: discount factor, in [0, 1]. Default to 0.99.
     :param BaseNoise exploration_noise: the exploration noise,
     :param bool reward_normalization: normalize the reward to Normal(0, 1),
         Default to False.
     :param int estimation_step: the number of steps to look ahead. Default to 1.
+    :param bool action_scaling: whether to map actions from range [-1, 1] to range
+        [action_space.low, action_space.high]. Default to True.
+    :param str action_bound_method: method to bound action to range [-1, 1], can be
+        either "clip" (simply clip the action), "tanh" (apply tanh squashing), or an
+        empty string for no bounding. Default to "clip".
+    :param Optional[gym.Space] action_space: env's action space; required if you
+        want to use "action_scaling" or "action_bound_method". Default to None.
 
     .. seealso::
 
@@ -38,15 +43,17 @@ def __init__(
         self,
         actor: Optional[torch.nn.Module],
         actor_optim: Optional[torch.optim.Optimizer],
         critic: Optional[torch.nn.Module],
         critic_optim: Optional[torch.optim.Optimizer],
-        action_range: Tuple[float, float],
         tau: float = 0.005,
         gamma: float = 0.99,
         exploration_noise: Optional[BaseNoise] = GaussianNoise(sigma=0.1),
         reward_normalization: bool = False,
         estimation_step: int = 1,
+        action_scaling: bool = True,
+        action_bound_method: str = "clip",
         **kwargs: Any,
     ) -> None:
-        super().__init__(**kwargs)
+        super().__init__(action_scaling=action_scaling,
+                         action_bound_method=action_bound_method, **kwargs)
         if actor is not None and actor_optim is not None:
             self.actor: torch.nn.Module = actor
             self.actor_old = deepcopy(actor)
@@ -62,9 +69,6 @@ def __init__(
         assert 0.0 <= gamma <= 1.0, "gamma should be in [0, 1]"
         self._gamma = gamma
         self._noise = exploration_noise
-        self._range = action_range
-        self._action_bias = (action_range[0] + action_range[1]) / 2.0
-        self._action_scale = (action_range[1] - action_range[0]) / 2.0
         # it is only a little difference to use GaussianNoise
         # self.noise = OUNoise()
         self._rew_norm = reward_normalization
@@ -128,8 +132,6 @@ def forward(
         model = getattr(self, model)
         obs = batch[input]
         actions, h = model(obs, state=state, info=batch.info)
-        actions += self._action_bias
-        actions = actions.clamp(self._range[0], self._range[1])
         return Batch(act=actions, state=h)
 
     @staticmethod
@@ -168,5 +170,4 @@ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
     def exploration_noise(self, act: np.ndarray, batch: Batch) -> np.ndarray:
         if self._noise:
             act = act + self._noise(act.shape)
-            act = act.clip(self._range[0], self._range[1])
         return act
diff --git a/tianshou/policy/modelfree/discrete_sac.py b/tianshou/policy/modelfree/discrete_sac.py
index a53bbbbf8..d4364db7d 100644
--- a/tianshou/policy/modelfree/discrete_sac.py
+++ b/tianshou/policy/modelfree/discrete_sac.py
@@ -49,10 +49,10 @@ def __init__(
         estimation_step: int = 1,
         **kwargs: Any,
     ) -> None:
-        super().__init__(actor, actor_optim, critic1, critic1_optim, critic2,
-                         critic2_optim, (-np.inf, np.inf), tau, gamma, alpha,
-                         reward_normalization, estimation_step,
-                         **kwargs)
+        super().__init__(
+            actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
+            tau, gamma, alpha, reward_normalization, estimation_step,
+            action_scaling=False, action_bound_method="", **kwargs)
         self._alpha: Union[float, torch.Tensor]
 
     def forward(  # type: ignore
diff --git a/tianshou/policy/modelfree/pg.py b/tianshou/policy/modelfree/pg.py
index 080ba70a2..742423aa5 100644
--- a/tianshou/policy/modelfree/pg.py
+++ b/tianshou/policy/modelfree/pg.py
@@ -15,6 +15,13 @@ class PGPolicy(BasePolicy):
     :param dist_fn: distribution class for computing the action.
     :type dist_fn: Type[torch.distributions.Distribution]
     :param float discount_factor: in [0, 1]. Default to 0.99.
+    :param bool action_scaling: whether to map actions from range [-1, 1] to range
+        [action_space.low, action_space.high]. Default to True.
+    :param str action_bound_method: method to bound action to range [-1, 1], can be
+        either "clip" (simply clip the action), "tanh" (apply tanh squashing), or an
+        empty string for no bounding. Default to "clip".
+    :param Optional[gym.Space] action_space: env's action space; required if you
+        want to use "action_scaling" or "action_bound_method". Default to None.
 
     .. seealso::
 
@@ -29,9 +36,12 @@ def __init__(
         self,
         model: torch.nn.Module,
         optim: torch.optim.Optimizer,
         dist_fn: Type[torch.distributions.Distribution],
         discount_factor: float = 0.99,
         reward_normalization: bool = False,
+        action_scaling: bool = True,
+        action_bound_method: str = "clip",
         **kwargs: Any,
     ) -> None:
-        super().__init__(**kwargs)
+        super().__init__(action_scaling=action_scaling,
+                         action_bound_method=action_bound_method, **kwargs)
         if model is not None:
             self.model: torch.nn.Module = model
         self.optim = optim
diff --git a/tianshou/policy/modelfree/ppo.py b/tianshou/policy/modelfree/ppo.py
index 953829195..8f96ce38c 100644
--- a/tianshou/policy/modelfree/ppo.py
+++ b/tianshou/policy/modelfree/ppo.py
@@ -1,7 +1,7 @@
 import torch
 import numpy as np
 from torch import nn
-from typing import Any, Dict, List, Type, Tuple, Union, Optional
+from typing import Any, Dict, List, Type, Union, Optional
 
 from tianshou.policy import PGPolicy
 from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch_as
@@ -23,8 +23,6 @@ class PPOPolicy(PGPolicy):
         paper. Default to 0.2.
     :param float vf_coef: weight for value loss. Default to 0.5.
     :param float ent_coef: weight for entropy loss. Default to 0.01.
-    :param action_range: the action range (minimum, maximum).
-    :type action_range: (float, float)
     :param float gae_lambda: in [0, 1], param for Generalized Advantage Estimation.
         Default to 0.95.
     :param float dual_clip: a parameter c mentioned in arXiv:1912.09729 Equ. 5,
     :param bool value_clip: a parameter mentioned in arXiv:1811.02553 Sec. 4.1.
         Default to True.
         depends on the size of available memory and the memory cost of the
         model; should be as large as possible within the memory constraint.
         Default to 256.
+    :param bool action_scaling: whether to map actions from range [-1, 1] to range
+        [action_space.low, action_space.high]. Default to True.
+    :param str action_bound_method: method to bound action to range [-1, 1], can be
+        either "clip" (simply clip the action), "tanh" (apply tanh squashing), or an
+        empty string for no bounding. Default to "clip".
+    :param Optional[gym.Space] action_space: env's action space; required if you
+        want to use "action_scaling" or "action_bound_method". Default to None.
 
     .. seealso::
 
@@ -56,7 +61,6 @@ def __init__(
         self,
         actor: torch.nn.Module,
         critic: torch.nn.Module,
         optim: torch.optim.Optimizer,
         dist_fn: Type[torch.distributions.Distribution],
         eps_clip: float = 0.2,
         vf_coef: float = 0.5,
         ent_coef: float = 0.01,
-        action_range: Optional[Tuple[float, float]] = None,
         gae_lambda: float = 0.95,
         dual_clip: Optional[float] = None,
         value_clip: bool = True,
@@ -69,7 +73,6 @@ def __init__(
         self._eps_clip = eps_clip
         self._weight_vf = vf_coef
         self._weight_ent = ent_coef
-        self._range = action_range
         self.actor = actor
         self.critic = critic
         self._batch = max_batchsize
@@ -135,8 +138,6 @@ def forward(
         else:
             dist = self.dist_fn(logits)
         act = dist.sample()
-        if self._range:
-            act = act.clamp(self._range[0], self._range[1])
         return Batch(logits=logits, act=act, state=h, dist=dist)
 
     def learn(  # type: ignore
diff --git a/tianshou/policy/modelfree/sac.py b/tianshou/policy/modelfree/sac.py
index ac81cbfc2..6fa5911f5 100644
--- a/tianshou/policy/modelfree/sac.py
+++ b/tianshou/policy/modelfree/sac.py
@@ -6,7 +6,7 @@
 
 from tianshou.policy import DDPGPolicy
 from tianshou.exploration import BaseNoise
-from tianshou.data import Batch, ReplayBuffer
+from tianshou.data import Batch, ReplayBuffer, to_torch_as
 
 
 class SACPolicy(DDPGPolicy):
@@ -21,8 +21,6 @@ class SACPolicy(DDPGPolicy):
     :param torch.nn.Module critic2: the second critic network. (s, a -> Q(s, a))
     :param torch.optim.Optimizer critic2_optim: the optimizer for the second
         critic network.
-    :param action_range: the action range (minimum, maximum).
-    :type action_range: Tuple[float, float]
     :param float tau: param for soft update of the target network. Default to 0.005.
     :param float gamma: discount factor, in [0, 1]. Default to 0.99.
     :param (float, torch.Tensor, torch.optim.Optimizer) or float alpha: entropy
     :param bool deterministic_eval: whether to use deterministic action (mean
         of Gaussian policy) instead of stochastic action sampled by the policy.
         Default to True.
+    :param bool action_scaling: whether to map actions from range [-1, 1] to range
+        [action_space.low, action_space.high]. Default to True.
+    :param str action_bound_method: method to bound action to range [-1, 1], can be
+        either "clip" (simply clip the action), "tanh" (apply tanh squashing), or an
+        empty string for no bounding. Default to "tanh".
+    :param Optional[gym.Space] action_space: env's action space; required if you
+        want to use "action_scaling" or "action_bound_method". Default to None.
 
     .. seealso::
 
@@ -51,7 +56,6 @@ def __init__(
         self,
         actor: torch.nn.Module,
         actor_optim: torch.optim.Optimizer,
         critic1: torch.nn.Module,
         critic1_optim: torch.optim.Optimizer,
         critic2: torch.nn.Module,
         critic2_optim: torch.optim.Optimizer,
-        action_range: Tuple[float, float],
         tau: float = 0.005,
         gamma: float = 0.99,
         alpha: Union[float, Tuple[float, torch.Tensor, torch.optim.Optimizer]] = 0.2,
         reward_normalization: bool = False,
         estimation_step: int = 1,
         exploration_noise: Optional[BaseNoise] = None,
         deterministic_eval: bool = True,
+        action_bound_method: str = "tanh",
         **kwargs: Any,
     ) -> None:
-        super().__init__(None, None, None, None, action_range, tau, gamma,
-                         exploration_noise, reward_normalization,
-                         estimation_step, **kwargs)
+        super().__init__(
+            None, None, None, None, tau, gamma, exploration_noise,
+            reward_normalization, estimation_step,
+            action_bound_method=action_bound_method, **kwargs)
         self.actor, self.actor_optim = actor, actor_optim
         self.critic1, self.critic1_old = critic1, deepcopy(critic1)
         self.critic1_old.eval()
@@ -110,24 +116,26 @@ def forward(  # type: ignore
         assert isinstance(logits, tuple)
         dist = Independent(Normal(*logits), 1)
         if self._deterministic_eval and not self.training:
-            x = logits[0]
+            act = logits[0]
         else:
-            x = dist.rsample()
-        y = torch.tanh(x)
-        act = y * self._action_scale + self._action_bias
-        # __eps is used to avoid log of zero/negative number.
-        y = self._action_scale * (1 - y.pow(2)) + self.__eps
-        # Compute logprob from Gaussian, and then apply correction for Tanh squashing.
-        # You can check out the original SAC paper (arXiv 1801.01290): Eq 21.
-        # in appendix C to get some understanding of this equation.
-        log_prob = dist.log_prob(x).unsqueeze(-1)
-        log_prob = log_prob - torch.log(y).sum(-1, keepdim=True)
-
+            act = dist.rsample()
+        log_prob = dist.log_prob(act).unsqueeze(-1)
+        if self.action_bound_method == "tanh" and self.action_space is not None:
+            # apply correction for Tanh squashing when computing logprob from Gaussian
+            # You can check out the original SAC paper (arXiv 1801.01290): Eq 21.
+            # in appendix C to get some understanding of this equation.
+            if self.action_scaling:
+                action_scale = to_torch_as(
+                    (self.action_space.high - self.action_space.low) / 2.0, act)
+            else:
+                action_scale = 1.0  # type: ignore
+            squashed_action = torch.tanh(act)
+            log_prob = log_prob - torch.log(
+                action_scale * (1 - squashed_action.pow(2)) + self.__eps
+            ).sum(-1, keepdim=True)
         return Batch(logits=logits, act=act, state=h, dist=dist, log_prob=log_prob)
 
-    def _target_q(
-        self, buffer: ReplayBuffer, indice: np.ndarray
-    ) -> torch.Tensor:
+    def _target_q(self, buffer: ReplayBuffer, indice: np.ndarray) -> torch.Tensor:
         batch = buffer[indice]  # batch.obs: s_{t+n}
         obs_next_result = self(batch, input='obs_next')
         a_ = obs_next_result.act
diff --git a/tianshou/policy/modelfree/td3.py b/tianshou/policy/modelfree/td3.py
index 09f288ff6..96843a3f8 100644
--- a/tianshou/policy/modelfree/td3.py
+++ b/tianshou/policy/modelfree/td3.py
@@ -1,7 +1,7 @@
 import torch
 import numpy as np
 from copy import deepcopy
-from typing import Any, Dict, Tuple, Optional
+from typing import Any, Dict, Optional
 
 from tianshou.policy import DDPGPolicy
 from tianshou.data import Batch, ReplayBuffer
@@ -20,8 +20,6 @@ class TD3Policy(DDPGPolicy):
     :param torch.nn.Module critic2: the second critic network. (s, a -> Q(s, a))
     :param torch.optim.Optimizer critic2_optim: the optimizer for the second
         critic network.
-    :param action_range: the action range (minimum, maximum).
-    :type action_range: Tuple[float, float]
     :param float tau: param for soft update of the target network. Default to 0.005.
     :param float gamma: discount factor, in [0, 1]. Default to 0.99.
     :param float exploration_noise: the exploration noise, add to the action.
         Default to ``GaussianNoise(sigma=0.1)``
     :param float policy_noise: the noise used in updating policy network.
         Default to 0.2.
     :param int update_actor_freq: the update frequency of actor network.
         Default to 2.
     :param float noise_clip: the clipping range used in updating policy network.
         Default to 0.5.
     :param bool reward_normalization: normalize the reward to Normal(0, 1).
         Default to False.
+    :param bool action_scaling: whether to map actions from range [-1, 1] to range
+        [action_space.low, action_space.high]. Default to True.
+    :param str action_bound_method: method to bound action to range [-1, 1], can be
+        either "clip" (simply clip the action), "tanh" (apply tanh squashing), or an
+        empty string for no bounding. Default to "clip".
+    :param Optional[gym.Space] action_space: env's action space; required if you
+        want to use "action_scaling" or "action_bound_method". Default to None.
 
     .. seealso::
 
@@ -49,7 +54,6 @@ def __init__(
         self,
         actor: torch.nn.Module,
         actor_optim: torch.optim.Optimizer,
         critic1: torch.nn.Module,
         critic1_optim: torch.optim.Optimizer,
         critic2: torch.nn.Module,
         critic2_optim: torch.optim.Optimizer,
-        action_range: Tuple[float, float],
         tau: float = 0.005,
         gamma: float = 0.99,
         exploration_noise: Optional[BaseNoise] = GaussianNoise(sigma=0.1),
         policy_noise: float = 0.2,
         update_actor_freq: int = 2,
         noise_clip: float = 0.5,
         reward_normalization: bool = False,
         estimation_step: int = 1,
         **kwargs: Any,
     ) -> None:
-        super().__init__(actor, actor_optim, None, None, action_range, tau, gamma,
+        super().__init__(actor, actor_optim, None, None, tau, gamma,
                          exploration_noise, reward_normalization,
                          estimation_step, **kwargs)
         self.critic1, self.critic1_old = critic1, deepcopy(critic1)
@@ -98,7 +102,6 @@ def _target_q(self, buffer: ReplayBuffer, indice: np.ndarray) -> torch.Tensor:
         if self._noise_clip > 0.0:
             noise = noise.clamp(-self._noise_clip, self._noise_clip)
         a_ += noise
-        a_ = a_.clamp(self._range[0], self._range[1])
         target_q = torch.min(
             self.critic1_old(batch.obs_next, a_),
             self.critic2_old(batch.obs_next, a_))
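
Note: the snippet below is a minimal standalone sketch of the bounding and scaling
steps performed by the new BasePolicy.map_action; the `remap` helper and the
Pendulum-v0 environment are illustrative assumptions only, not part of this patch.
The collector now passes policy.forward's raw output through map_action right
before env.step, while the unmapped action is what gets stored in the replay buffer.

    import gym
    import numpy as np

    def remap(act: np.ndarray, space: gym.spaces.Box,
              action_scaling: bool = True,
              action_bound_method: str = "clip") -> np.ndarray:
        # bounding: squeeze the raw network output into [-1, 1]
        if action_bound_method == "clip":
            act = np.clip(act, -1.0, 1.0)
        elif action_bound_method == "tanh":
            act = np.tanh(act)
        # scaling: affine map from [-1, 1] to [space.low, space.high]
        if action_scaling:
            low, high = space.low, space.high
            act = low + (high - low) * (act + 1.0) / 2.0
        return act

    env = gym.make("Pendulum-v0")      # action space is Box(low=-2.0, high=2.0)
    raw_act = np.array([10.0])         # unbounded output, e.g. from a DDPG actor
    print(remap(raw_act, env.action_space))  # -> [2.], the value sent to env.step
    # the raw 10.0 (not 2.0) is what the collector saves into the replay buffer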