From 441c2fb3e1ff118cc224f4292c9209cc50c32941 Mon Sep 17 00:00:00 2001
From: chy <308604256@qq.com>
Date: Wed, 24 Mar 2021 15:09:12 +0800
Subject: [PATCH 1/5] refactor A2C/PPO, change behavior of value normalization

---
 tianshou/policy/modelfree/a2c.py | 57 ++++++++++++++++++------------
 tianshou/policy/modelfree/pg.py  |  6 ++--
 tianshou/policy/modelfree/ppo.py | 59 ++++++++++++++++----------------
 3 files changed, 68 insertions(+), 54 deletions(-)

diff --git a/tianshou/policy/modelfree/a2c.py b/tianshou/policy/modelfree/a2c.py
index 3dd1e561a..03930c14c 100644
--- a/tianshou/policy/modelfree/a2c.py
+++ b/tianshou/policy/modelfree/a2c.py
@@ -5,7 +5,7 @@
 from typing import Any, Dict, List, Type, Optional
 
 from tianshou.policy import PGPolicy
-from tianshou.data import Batch, ReplayBuffer, to_torch_as, to_numpy
+from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch_as
 
 
 class A2CPolicy(PGPolicy):
@@ -25,8 +25,8 @@ class A2CPolicy(PGPolicy):
         Default to None.
     :param float gae_lambda: in [0, 1], param for Generalized Advantage
         Estimation. Default to 0.95.
-    :param bool reward_normalization: normalize the reward to Normal(0, 1).
-        Default to False.
+    :param bool reward_normalization: normalize estimated values to
+        have std close to 1. Default to False.
     :param int max_batchsize: the maximum size of the batch when computing GAE,
         depends on the size of available memory and the memory cost of the
         model; should be as large as possible within the memory constraint.
@@ -72,22 +72,33 @@ def __init__(
     def process_fn(
         self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray
     ) -> Batch:
-        v_s_ = []
+        v_s, v_s_ = [], []
         with torch.no_grad():
             for b in batch.split(self._batch, shuffle=False, merge_last=True):
-                v_s_.append(to_numpy(self.critic(b.obs_next)))
-        v_s_ = np.concatenate(v_s_, axis=0)
-        if self._rew_norm:  # unnormalize v_s_
-            v_s_ = v_s_ * np.sqrt(self.ret_rms.var + self._eps) + self.ret_rms.mean
-        unnormalized_returns, _ = self.compute_episodic_return(
-            batch, buffer, indice, v_s_=v_s_,
+                v_s.append(self.critic(b.obs))
+                v_s_.append(self.critic(b.obs_next))
+        batch.v_s = torch.cat(v_s, dim=0).flatten()  # old value
+        v_s = to_numpy(batch.v_s)
+        v_s_ = to_numpy(torch.cat(v_s_, dim=0).flatten())
+        # when normalizing values, we do not subtract self.ret_rms.mean to be numerically
+        # consistent with OpenAI baselines' value normalization pipeline. Empirical
+        # study also shows that 'minus mean' will harm performances a tiny little bit
+        # due to unknown reasons(on Mujoco envs, not confident, though).
+ if self._rew_norm: # unnormalize v_s & v_s_ + v_s = v_s * np.sqrt(self.ret_rms.var + self._eps) + v_s_ = v_s_ * np.sqrt(self.ret_rms.var + self._eps) + unnormalized_returns, advantages = self.compute_episodic_return( + batch, buffer, indice, v_s_, v_s, gamma=self._gamma, gae_lambda=self._lambda) if self._rew_norm: - batch.returns = (unnormalized_returns - self.ret_rms.mean) / \ + batch.returns = unnormalized_returns / \ np.sqrt(self.ret_rms.var + self._eps) self.ret_rms.update(unnormalized_returns) else: batch.returns = unnormalized_returns + batch.act = to_torch_as(batch.act, batch.v_s) + batch.returns = to_torch_as(batch.returns, batch.v_s) + batch.adv = to_torch_as(advantages, batch.v_s) return batch def learn( # type: ignore @@ -96,24 +107,26 @@ def learn( # type: ignore losses, actor_losses, vf_losses, ent_losses = [], [], [], [] for _ in range(repeat): for b in batch.split(batch_size, merge_last=True): - self.optim.zero_grad() + # calculate loss for actor dist = self(b).dist - v = self.critic(b.obs).flatten() - a = to_torch_as(b.act, v) - r = to_torch_as(b.returns, v) - log_prob = dist.log_prob(a).reshape(len(r), -1).transpose(0, 1) - a_loss = -(log_prob * (r - v).detach()).mean() - vf_loss = F.mse_loss(r, v) # type: ignore + log_prob = dist.log_prob(b.act).reshape(len(b.adv), -1).transpose(0, 1) + actor_loss = -(log_prob * b.adv).mean() + # calculate loss for critic + value = self.critic(b.obs).flatten() + vf_loss = F.mse_loss(b.returns, value) # type: ignore + # calculate regularization and overall loss ent_loss = dist.entropy().mean() - loss = a_loss + self._weight_vf * vf_loss - self._weight_ent * ent_loss + loss = actor_loss + self._weight_vf * vf_loss \ + - self._weight_ent * ent_loss + self.optim.zero_grad() loss.backward() if self._grad_norm is not None: + # clip large gradient nn.utils.clip_grad_norm_( list(self.actor.parameters()) + list(self.critic.parameters()), - max_norm=self._grad_norm, - ) + max_norm=self._grad_norm) self.optim.step() - actor_losses.append(a_loss.item()) + actor_losses.append(actor_loss.item()) vf_losses.append(vf_loss.item()) ent_losses.append(ent_loss.item()) losses.append(loss.item()) diff --git a/tianshou/policy/modelfree/pg.py b/tianshou/policy/modelfree/pg.py index ac06f1c00..6e525695c 100644 --- a/tianshou/policy/modelfree/pg.py +++ b/tianshou/policy/modelfree/pg.py @@ -116,9 +116,9 @@ def learn( # type: ignore result = self(b) dist = result.dist a = to_torch_as(b.act, result.act) - r = to_torch_as(b.returns, result.act) - log_prob = dist.log_prob(a).reshape(len(r), -1).transpose(0, 1) - loss = -(log_prob * r).mean() + ret = to_torch_as(b.returns, result.act) + log_prob = dist.log_prob(a).reshape(len(ret), -1).transpose(0, 1) + loss = -(log_prob * ret).mean() loss.backward() self.optim.step() losses.append(loss.item()) diff --git a/tianshou/policy/modelfree/ppo.py b/tianshou/policy/modelfree/ppo.py index db7a22c6f..d65ce42e0 100644 --- a/tianshou/policy/modelfree/ppo.py +++ b/tianshou/policy/modelfree/ppo.py @@ -17,21 +17,21 @@ class PPOPolicy(A2CPolicy): :param dist_fn: distribution class for computing the action. :type dist_fn: Type[torch.distributions.Distribution] :param float discount_factor: in [0, 1]. Default to 0.99. - :param float max_grad_norm: clipping gradients in back propagation. - Default to None. :param float eps_clip: :math:`\epsilon` in :math:`L_{CLIP}` in the original paper. Default to 0.2. - :param float vf_coef: weight for value loss. Default to 0.5. - :param float ent_coef: weight for entropy loss. Default to 0.01. 
-    :param float gae_lambda: in [0, 1], param for Generalized Advantage
-        Estimation. Default to 0.95.
     :param float dual_clip: a parameter c mentioned in arXiv:1912.09729 Equ. 5,
         where c > 1 is a constant indicating the lower bound. Default to 5.0
         (set None if you do not want to use it).
     :param bool value_clip: a parameter mentioned in arXiv:1811.02553 Sec. 4.1.
         Default to True.
-    :param bool reward_normalization: normalize the returns and advantage to
-        Normal(0, 1). Default to False.
+    :param float vf_coef: weight for value loss. Default to 0.5.
+    :param float ent_coef: weight for entropy loss. Default to 0.01.
+    :param float max_grad_norm: clipping gradients in back propagation.
+        Default to None.
+    :param float gae_lambda: in [0, 1], param for Generalized Advantage
+        Estimation. Default to 0.95.
+    :param bool reward_normalization: normalize estimated values to have std close
+        to 1, also normalize the advantage to Normal(0, 1). Default to False.
     :param int max_batchsize: the maximum size of the batch when computing GAE,
         depends on the size of available memory and the memory cost of the
         model; should be as large as possible within the memory constraint.
@@ -58,20 +58,13 @@ def __init__(
         critic: torch.nn.Module,
         optim: torch.optim.Optimizer,
         dist_fn: Type[torch.distributions.Distribution],
-        max_grad_norm: Optional[float] = None,
         eps_clip: float = 0.2,
-        vf_coef: float = 0.5,
-        ent_coef: float = 0.01,
-        gae_lambda: float = 0.95,
         dual_clip: Optional[float] = None,
         value_clip: bool = True,
-        max_batchsize: int = 256,
         **kwargs: Any,
     ) -> None:
         super().__init__(
-            actor, critic, optim, dist_fn, max_grad_norm=max_grad_norm,
-            vf_coef=vf_coef, ent_coef=ent_coef, gae_lambda=gae_lambda,
-            max_batchsize=max_batchsize, **kwargs)
+            actor, critic, optim, dist_fn, **kwargs)
         self._eps_clip = eps_clip
         assert dual_clip is None or dual_clip > 1.0, \
             "Dual-clip PPO parameter should greater than 1.0."
@@ -90,18 +83,22 @@ def process_fn(
         batch.v_s = torch.cat(v_s, dim=0).flatten()  # old value
         v_s = to_numpy(batch.v_s)
         v_s_ = to_numpy(torch.cat(v_s_, dim=0).flatten())
+        # when normalizing values, we do not subtract self.ret_rms.mean to be numerically
+        # consistent with OpenAI baselines' value normalization pipeline. Empirical
+        # study also shows that 'minus mean' will harm performances a tiny little bit
+        # due to unknown reasons(on Mujoco envs, not confident, though).
if self._rew_norm: # unnormalize v_s & v_s_ - v_s = v_s * np.sqrt(self.ret_rms.var + self._eps) + self.ret_rms.mean - v_s_ = v_s_ * np.sqrt(self.ret_rms.var + self._eps) + self.ret_rms.mean + v_s = v_s * np.sqrt(self.ret_rms.var + self._eps) + v_s_ = v_s_ * np.sqrt(self.ret_rms.var + self._eps) unnormalized_returns, advantages = self.compute_episodic_return( batch, buffer, indice, v_s_, v_s, gamma=self._gamma, gae_lambda=self._lambda) if self._rew_norm: - batch.returns = (unnormalized_returns - self.ret_rms.mean) / \ + batch.returns = unnormalized_returns / \ np.sqrt(self.ret_rms.var + self._eps) self.ret_rms.update(unnormalized_returns) mean, std = np.mean(advantages), np.std(advantages) - advantages = (advantages - mean) / std # per-batch norm + advantages = (advantages - mean) / std else: batch.returns = unnormalized_returns batch.act = to_torch_as(batch.act, batch.v_s) @@ -116,8 +113,8 @@ def learn( # type: ignore losses, clip_losses, vf_losses, ent_losses = [], [], [], [] for _ in range(repeat): for b in batch.split(batch_size, merge_last=True): + # calculate loss for actor dist = self(b).dist - value = self.critic(b.obs).flatten() ratio = (dist.log_prob(b.act) - b.logp_old).exp().float() ratio = ratio.reshape(ratio.size(0), -1).transpose(0, 1) surr1 = ratio * b.adv @@ -128,7 +125,8 @@ def learn( # type: ignore ).mean() else: clip_loss = -torch.min(surr1, surr2).mean() - clip_losses.append(clip_loss.item()) + # calculate loss for critic + value = self.critic(b.obs).flatten() if self._value_clip: v_clip = b.v_s + (value - b.v_s).clamp( -self._eps_clip, self._eps_clip) @@ -137,20 +135,23 @@ def learn( # type: ignore vf_loss = 0.5 * torch.max(vf1, vf2).mean() else: vf_loss = 0.5 * (b.returns - value).pow(2).mean() - vf_losses.append(vf_loss.item()) - e_loss = dist.entropy().mean() - ent_losses.append(e_loss.item()) + # calculate regularization and overall loss + ent_loss = dist.entropy().mean() loss = clip_loss + self._weight_vf * vf_loss \ - - self._weight_ent * e_loss - losses.append(loss.item()) + - self._weight_ent * ent_loss self.optim.zero_grad() loss.backward() if self._grad_norm is not None: + # clip large gradient nn.utils.clip_grad_norm_( list(self.actor.parameters()) + list(self.critic.parameters()), - self._grad_norm) + max_norm=self._grad_norm) self.optim.step() - # update learning rate if lr_scheduler is given + clip_losses.append(clip_loss.item()) + vf_losses.append(vf_loss.item()) + ent_losses.append(ent_loss.item()) + losses.append(loss.item()) + # update learning rate if lr_scheduler is given if self.lr_scheduler is not None: self.lr_scheduler.step() From b5c6e08e04a3b42299bfdeee8f9d35a8a987b037 Mon Sep 17 00:00:00 2001 From: Trinkle23897 Date: Wed, 24 Mar 2021 17:20:23 +0800 Subject: [PATCH 2/5] update --- tianshou/policy/modelfree/a2c.py | 19 +++++++++---------- tianshou/policy/modelfree/ppo.py | 27 ++++++++++++--------------- 2 files changed, 21 insertions(+), 25 deletions(-) diff --git a/tianshou/policy/modelfree/a2c.py b/tianshou/policy/modelfree/a2c.py index 03930c14c..478bac6fe 100644 --- a/tianshou/policy/modelfree/a2c.py +++ b/tianshou/policy/modelfree/a2c.py @@ -21,12 +21,12 @@ class A2CPolicy(PGPolicy): :param float discount_factor: in [0, 1]. Default to 0.99. :param float vf_coef: weight for value loss. Default to 0.5. :param float ent_coef: weight for entropy loss. Default to 0.01. - :param float max_grad_norm: clipping gradients in back propagation. - Default to None. - :param float gae_lambda: in [0, 1], param for Generalized Advantage - Estimation. 
Default to 0.95.
-    :param bool reward_normalization: normalize estimated values to
-        have std close to 1. Default to False.
+    :param float max_grad_norm: clipping gradients in back propagation. Default to
+        None.
+    :param float gae_lambda: in [0, 1], param for Generalized Advantage Estimation.
+        Default to 0.95.
+    :param bool reward_normalization: normalize estimated values to have std close to
+        1. Default to False.
     :param int max_batchsize: the maximum size of the batch when computing GAE,
         depends on the size of available memory and the memory cost of the
         model; should be as large as possible within the memory constraint.
@@ -82,8 +82,8 @@ def process_fn(
         v_s_ = to_numpy(torch.cat(v_s_, dim=0).flatten())
         # when normalizing values, we do not subtract self.ret_rms.mean to be numerically
         # consistent with OpenAI baselines' value normalization pipeline. Empirical
-        # study also shows that 'minus mean' will harm performances a tiny little bit
-        # due to unknown reasons(on Mujoco envs, not confident, though).
+        # study also shows that "minus mean" will harm performances a tiny little bit
+        # due to unknown reasons (on Mujoco envs, not confident, though).
         if self._rew_norm:  # unnormalize v_s & v_s_
             v_s = v_s * np.sqrt(self.ret_rms.var + self._eps)
             v_s_ = v_s_ * np.sqrt(self.ret_rms.var + self._eps)
@@ -120,8 +120,7 @@ def learn(  # type: ignore
                     - self._weight_ent * ent_loss
                 self.optim.zero_grad()
                 loss.backward()
-                if self._grad_norm is not None:
-                    # clip large gradient
+                if self._grad_norm is not None:  # clip large gradient
                     nn.utils.clip_grad_norm_(
                         list(self.actor.parameters()) + list(self.critic.parameters()),
                         max_norm=self._grad_norm)
diff --git a/tianshou/policy/modelfree/ppo.py b/tianshou/policy/modelfree/ppo.py
index d65ce42e0..643c6f469 100644
--- a/tianshou/policy/modelfree/ppo.py
+++ b/tianshou/policy/modelfree/ppo.py
@@ -26,16 +26,15 @@ class PPOPolicy(A2CPolicy):
         Default to True.
     :param float vf_coef: weight for value loss. Default to 0.5.
     :param float ent_coef: weight for entropy loss. Default to 0.01.
-    :param float max_grad_norm: clipping gradients in back propagation.
-        Default to None.
-    :param float gae_lambda: in [0, 1], param for Generalized Advantage
-        Estimation. Default to 0.95.
+    :param float max_grad_norm: clipping gradients in back propagation. Default to
+        None.
+    :param float gae_lambda: in [0, 1], param for Generalized Advantage Estimation.
+        Default to 0.95.
     :param bool reward_normalization: normalize estimated values to have std close
         to 1, also normalize the advantage to Normal(0, 1). Default to False.
     :param int max_batchsize: the maximum size of the batch when computing GAE,
-        depends on the size of available memory and the memory cost of the
-        model; should be as large as possible within the memory constraint.
-        Default to 256.
+        depends on the size of available memory and the memory cost of the model;
+        should be as large as possible within the memory constraint. Default to 256.
     :param bool action_scaling: whether to map actions from range [-1, 1] to range
         [action_spaces.low, action_spaces.high]. Default to True.
     :param str action_bound_method: method to bound action to range [-1, 1], can be
@@ -63,8 +62,7 @@ def __init__(
         value_clip: bool = True,
         **kwargs: Any,
     ) -> None:
-        super().__init__(
-            actor, critic, optim, dist_fn, **kwargs)
+        super().__init__(actor, critic, optim, dist_fn, **kwargs)
         self._eps_clip = eps_clip
         assert dual_clip is None or dual_clip > 1.0, \
             "Dual-clip PPO parameter should greater than 1.0."
@@ -85,8 +83,8 @@ def process_fn(
         v_s_ = to_numpy(torch.cat(v_s_, dim=0).flatten())
         # when normalizing values, we do not subtract self.ret_rms.mean to be numerically
         # consistent with OpenAI baselines' value normalization pipeline. Empirical
-        # study also shows that 'minus mean' will harm performances a tiny little bit
-        # due to unknown reasons(on Mujoco envs, not confident, though).
+        # study also shows that "minus mean" will harm performances a tiny little bit
+        # due to unknown reasons (on Mujoco envs, not confident, though).
         if self._rew_norm:  # unnormalize v_s & v_s_
             v_s = v_s * np.sqrt(self.ret_rms.var + self._eps)
             v_s_ = v_s_ * np.sqrt(self.ret_rms.var + self._eps)
@@ -98,7 +96,7 @@ def process_fn(
                 np.sqrt(self.ret_rms.var + self._eps)
             self.ret_rms.update(unnormalized_returns)
             mean, std = np.mean(advantages), np.std(advantages)
-            advantages = (advantages - mean) / std
+            advantages = (advantages - mean) / std  # per-batch norm
         else:
             batch.returns = unnormalized_returns
         batch.act = to_torch_as(batch.act, batch.v_s)
@@ -141,8 +139,7 @@ def learn(  # type: ignore
                     - self._weight_ent * ent_loss
                 self.optim.zero_grad()
                 loss.backward()
-                if self._grad_norm is not None:
-                    # clip large gradient
+                if self._grad_norm is not None:  # clip large gradient
                     nn.utils.clip_grad_norm_(
                         list(self.actor.parameters()) + list(self.critic.parameters()),
                         max_norm=self._grad_norm)
@@ -151,7 +148,7 @@ def learn(  # type: ignore
                 vf_losses.append(vf_loss.item())
                 ent_losses.append(ent_loss.item())
                 losses.append(loss.item())
-            # update learning rate if lr_scheduler is given
+        # update learning rate if lr_scheduler is given
         if self.lr_scheduler is not None:
             self.lr_scheduler.step()

From 48a7cb9a6621ea2abf30e666b901034e6d41295f Mon Sep 17 00:00:00 2001
From: Trinkle23897
Date: Wed, 24 Mar 2021 17:21:30 +0800
Subject: [PATCH 3/5] fix ci

---
 tianshou/policy/modelfree/a2c.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tianshou/policy/modelfree/a2c.py b/tianshou/policy/modelfree/a2c.py
index 478bac6fe..bbf46221f 100644
--- a/tianshou/policy/modelfree/a2c.py
+++ b/tianshou/policy/modelfree/a2c.py
@@ -113,7 +113,7 @@ def learn(  # type: ignore
                 actor_loss = -(log_prob * b.adv).mean()
                 # calculate loss for critic
                 value = self.critic(b.obs).flatten()
-                vf_loss = F.mse_loss(b.returns, value)  # type: ignore
+                vf_loss = F.mse_loss(b.returns, value)
                 # calculate regularization and overall loss
                 ent_loss = dist.entropy().mean()
                 loss = actor_loss + self._weight_vf * vf_loss \

From 01c899a657268683b44b4fa925b5d20ab5e9f5d0 Mon Sep 17 00:00:00 2001
From: Trinkle23897
Date: Thu, 25 Mar 2021 08:14:14 +0800
Subject: [PATCH 4/5] a2c constantly 10s

---
 test/discrete/test_a2c_with_il.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/test/discrete/test_a2c_with_il.py b/test/discrete/test_a2c_with_il.py
index e9003ce8b..f7e3a86de 100644
--- a/test/discrete/test_a2c_with_il.py
+++ b/test/discrete/test_a2c_with_il.py
@@ -20,22 +20,22 @@ def get_args():
     parser.add_argument('--task', type=str, default='CartPole-v0')
     parser.add_argument('--seed', type=int, default=1)
     parser.add_argument('--buffer-size', type=int, default=20000)
-    parser.add_argument('--lr', type=float, default=3e-4)
+    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--il-lr', type=float, default=1e-3)
     parser.add_argument('--gamma', type=float, default=0.9)
     parser.add_argument('--epoch', type=int, default=10)
     parser.add_argument('--step-per-epoch', type=int, default=50000)
parser.add_argument('--il-step-per-epoch', type=int, default=1000) - parser.add_argument('--episode-per-collect', type=int, default=8) - parser.add_argument('--step-per-collect', type=int, default=8) - parser.add_argument('--update-per-step', type=float, default=0.125) + parser.add_argument('--episode-per-collect', type=int, default=16) + parser.add_argument('--step-per-collect', type=int, default=16) + parser.add_argument('--update-per-step', type=float, default=1 / 16) parser.add_argument('--repeat-per-collect', type=int, default=1) parser.add_argument('--batch-size', type=int, default=64) parser.add_argument('--hidden-sizes', type=int, - nargs='*', default=[128, 128, 128]) + nargs='*', default=[64, 64]) parser.add_argument('--imitation-hidden-sizes', type=int, nargs='*', default=[128]) - parser.add_argument('--training-num', type=int, default=8) + parser.add_argument('--training-num', type=int, default=16) parser.add_argument('--test-num', type=int, default=100) parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.) From 4772e71ec38d2e55a06e1eea289442f1e4a4d440 Mon Sep 17 00:00:00 2001 From: Trinkle23897 Date: Thu, 25 Mar 2021 09:56:05 +0800 Subject: [PATCH 5/5] fix readme --- examples/mujoco/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/mujoco/README.md b/examples/mujoco/README.md index 95ae47e05..7e719e248 100644 --- a/examples/mujoco/README.md +++ b/examples/mujoco/README.md @@ -43,7 +43,7 @@ This will start 10 experiments with different seeds. #### Example benchmark - + Other graphs can be found under `/examples/mujuco/benchmark/`
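
Appendix for readers of this patch series (not part of any commit): the sketch below illustrates the value-normalization flow that patches 1-2 implement in A2CPolicy.process_fn and PPOPolicy.process_fn. Old value predictions are un-normalized by the running std (the mean is deliberately not involved), GAE targets are computed on the un-normalized scale, and the returns handed back to the critic are re-scaled by std only. RunningMeanStd and gae() are simplified stand-ins for tianshou.utils.RunningMeanStd and BasePolicy.compute_episodic_return, and the toy numbers in the usage lines are made up for illustration.

import numpy as np


class RunningMeanStd:
    """Minimal running mean/std tracker (stand-in for tianshou.utils.RunningMeanStd)."""

    def __init__(self) -> None:
        self.mean, self.var, self.count = 0.0, 1.0, 1e-4

    def update(self, x: np.ndarray) -> None:
        # standard parallel mean/variance update (same scheme as OpenAI baselines)
        batch_mean, batch_var, batch_count = x.mean(), x.var(), len(x)
        delta = batch_mean - self.mean
        total = self.count + batch_count
        m2 = self.var * self.count + batch_var * batch_count \
            + delta ** 2 * self.count * batch_count / total
        self.mean += delta * batch_count / total
        self.var, self.count = m2 / total, total


def gae(rew, done, v_s, v_s_, gamma=0.99, lam=0.95):
    # plain Generalized Advantage Estimation on un-normalized values;
    # returns (returns, advantages)
    adv = np.zeros_like(rew, dtype=np.float64)
    running = 0.0
    for t in reversed(range(len(rew))):
        delta = rew[t] + gamma * v_s_[t] * (1.0 - done[t]) - v_s[t]
        running = delta + gamma * lam * (1.0 - done[t]) * running
        adv[t] = running
    return adv + v_s, adv


def process_targets(rew, done, v_s_norm, v_s_norm_, ret_rms, eps=1e-8, ppo=True):
    # the critic was trained on normalized returns, so rescale its outputs by
    # std only -- the mean is not added back (see the comment in the patch)
    std = np.sqrt(ret_rms.var + eps)
    v_s, v_s_ = v_s_norm * std, v_s_norm_ * std
    returns, adv = gae(rew, done, v_s, v_s_)
    # normalize the regression target with the *old* statistics, again without
    # subtracting the mean, then update the running statistics
    norm_returns = returns / np.sqrt(ret_rms.var + eps)
    ret_rms.update(returns)
    if ppo:  # PPO additionally normalizes advantages per batch
        adv = (adv - adv.mean()) / adv.std()
    return norm_returns, adv


# toy usage (made-up numbers): one 3-step episode
rms = RunningMeanStd()
rew = np.array([1.0, 1.0, 1.0])
done = np.array([0.0, 0.0, 1.0])
v_norm = np.array([0.5, 0.4, 0.3])       # critic(obs), normalized scale
v_norm_next = np.array([0.4, 0.3, 0.0])  # critic(obs_next), normalized scale
returns, adv = process_targets(rew, done, v_norm, v_norm_next, rms)

A second sketch condenses the per-minibatch PPO update as reorganized by patch 1 (actor loss, then critic loss, then the entropy bonus) into one function. The batch fields (b.act, b.adv, b.logp_old, b.v_s, b.returns) are the ones the patch stores in process_fn; the dual-clip branch follows Equ. 5 of arXiv:1912.09729, since that line is only partially visible in the hunks above.

import torch


def ppo_loss(dist, critic, b, eps_clip=0.2, dual_clip=None, value_clip=True,
             vf_coef=0.5, ent_coef=0.01):
    # actor: clipped surrogate objective, optionally dual-clipped
    ratio = (dist.log_prob(b.act) - b.logp_old).exp().float()
    ratio = ratio.reshape(ratio.size(0), -1).transpose(0, 1)
    surr1 = ratio * b.adv
    surr2 = ratio.clamp(1.0 - eps_clip, 1.0 + eps_clip) * b.adv
    if dual_clip:
        clip_loss = -torch.max(torch.min(surr1, surr2), dual_clip * b.adv).mean()
    else:
        clip_loss = -torch.min(surr1, surr2).mean()
    # critic: squared error, optionally clipped around the old value b.v_s
    value = critic(b.obs).flatten()
    if value_clip:
        v_clip = b.v_s + (value - b.v_s).clamp(-eps_clip, eps_clip)
        vf_loss = 0.5 * torch.max((b.returns - value).pow(2),
                                  (b.returns - v_clip).pow(2)).mean()
    else:
        vf_loss = 0.5 * (b.returns - value).pow(2).mean()
    # entropy bonus enters with a negative sign
    ent_loss = dist.entropy().mean()
    return clip_loss + vf_coef * vf_loss - ent_coef * ent_loss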