Commit: Refactor PG algorithm and change behavior of compute_episodic_return (thu-ml#319)

- simplify code
- apply value normalization (global) and advantage normalization (per-batch) in on-policy algorithms
ChenDRAG authored Mar 23, 2021
1 parent 5ac4102 commit 44ce78b
Showing 9 changed files with 109 additions and 192 deletions.
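The second bullet above distinguishes two normalization schemes. A minimal NumPy sketch of the difference (illustrative helper names, not the library's API):

import numpy as np

def normalize_advantage(adv, eps=1e-8):
    # per-batch: standardize with the statistics of the current minibatch only
    return (adv - adv.mean()) / (adv.std() + eps)

def normalize_value_targets(returns, running_mean, running_var, eps=1e-8):
    # global: scale with running statistics accumulated over the whole run;
    # the running statistics themselves are updated elsewhere from raw returns
    return (returns - running_mean) / np.sqrt(running_var + eps)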
24 changes: 12 additions & 12 deletions test/base/test_returns.py
@@ -30,9 +30,9 @@ def test_episodic_returns(size=2560):
for b in batch:
b.obs = b.act = 1
buf.add(b)
batch = fn(batch, buf, buf.sample_index(0), None, gamma=.1, gae_lambda=1)
returns, _ = fn(batch, buf, buf.sample_index(0), gamma=.1, gae_lambda=1)
ans = np.array([0, 1.23, 2.3, 3, 4.5, 5, 6.7, 7])
assert np.allclose(batch.returns, ans)
assert np.allclose(returns, ans)
buf.reset()
batch = Batch(
done=np.array([0, 1, 0, 1, 0, 1, 0.]),
@@ -41,9 +41,9 @@ def test_episodic_returns(size=2560):
for b in batch:
b.obs = b.act = 1
buf.add(b)
batch = fn(batch, buf, buf.sample_index(0), None, gamma=.1, gae_lambda=1)
returns, _ = fn(batch, buf, buf.sample_index(0), gamma=.1, gae_lambda=1)
ans = np.array([7.6, 6, 1.2, 2, 3.4, 4, 5])
assert np.allclose(batch.returns, ans)
assert np.allclose(returns, ans)
buf.reset()
batch = Batch(
done=np.array([0, 1, 0, 1, 0, 0, 1.]),
@@ -52,9 +52,9 @@ def test_episodic_returns(size=2560):
for b in batch:
b.obs = b.act = 1
buf.add(b)
batch = fn(batch, buf, buf.sample_index(0), None, gamma=.1, gae_lambda=1)
returns, _ = fn(batch, buf, buf.sample_index(0), gamma=.1, gae_lambda=1)
ans = np.array([7.6, 6, 1.2, 2, 3.45, 4.5, 5])
assert np.allclose(batch.returns, ans)
assert np.allclose(returns, ans)
buf.reset()
batch = Batch(
done=np.array([0, 0, 0, 1., 0, 0, 0, 1, 0, 0, 0, 1]),
@@ -64,12 +64,12 @@ def test_episodic_returns(size=2560):
b.obs = b.act = 1
buf.add(b)
v = np.array([2., 3., 4, -1, 5., 6., 7, -2, 8., 9., 10, -3])
ret = fn(batch, buf, buf.sample_index(0), v, gamma=0.99, gae_lambda=0.95)
returns = np.array([
returns, _ = fn(batch, buf, buf.sample_index(0), v, gamma=0.99, gae_lambda=0.95)
ground_truth = np.array([
454.8344, 376.1143, 291.298, 200.,
464.5610, 383.1085, 295.387, 201.,
474.2876, 390.1027, 299.476, 202.])
assert np.allclose(ret.returns, returns)
assert np.allclose(returns, ground_truth)
buf.reset()
batch = Batch(
done=np.array([0, 0, 0, 1., 0, 0, 0, 1, 0, 0, 0, 1]),
@@ -82,12 +82,12 @@ def test_episodic_returns(size=2560):
b.obs = b.act = 1
buf.add(b)
v = np.array([2., 3., 4, -1, 5., 6., 7, -2, 8., 9., 10, -3])
ret = fn(batch, buf, buf.sample_index(0), v, gamma=0.99, gae_lambda=0.95)
returns = np.array([
returns, _ = fn(batch, buf, buf.sample_index(0), v, gamma=0.99, gae_lambda=0.95)
ground_truth = np.array([
454.0109, 375.2386, 290.3669, 199.01,
462.9138, 381.3571, 293.5248, 199.02,
474.2876, 390.1027, 299.476, 202.])
assert np.allclose(ret.returns, returns)
assert np.allclose(returns, ground_truth)

if __name__ == '__main__':
buf = ReplayBuffer(size)
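Pulling the interleaved old and new lines together: compute_episodic_return no longer takes a rew_norm flag and now returns a (returns, advantage) tuple instead of writing batch.returns in place. A self-contained sketch of the new call, assuming fn in this test is BasePolicy.compute_episodic_return:

import numpy as np
from tianshou.data import Batch, ReplayBuffer
from tianshou.policy import BasePolicy

buf = ReplayBuffer(size=20)
batch = Batch(
    done=np.array([0, 1, 0, 1, 0, 1, 0.]),
    rew=np.array([7., 6, 1, 2, 3, 4, 5]),
)
for b in batch:
    b.obs = b.act = 1
    buf.add(b)

# without critic values gae_lambda must be 1 and the result is the plain
# discounted reward-to-go; the test above expects [7.6, 6, 1.2, 2, 3.4, 4, 5]
returns, advantage = BasePolicy.compute_episodic_return(
    batch, buf, buf.sample_index(0), gamma=.1, gae_lambda=1.)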
3 changes: 2 additions & 1 deletion test/continuous/test_ppo.py
@@ -91,7 +91,8 @@ def test_ppo(args=get_args()):
def dist(*logits):
return Independent(Normal(*logits), 1)
policy = PPOPolicy(
actor, critic, optim, dist, args.gamma,
actor, critic, optim, dist,
discount_factor=args.gamma,
max_grad_norm=args.max_grad_norm,
eps_clip=args.eps_clip,
vf_coef=args.vf_coef,
3 changes: 2 additions & 1 deletion test/discrete/test_a2c_with_il.py
@@ -78,7 +78,8 @@ def test_a2c_with_il(args=get_args()):
actor.parameters()).union(critic.parameters()), lr=args.lr)
dist = torch.distributions.Categorical
policy = A2CPolicy(
actor, critic, optim, dist, args.gamma, gae_lambda=args.gae_lambda,
actor, critic, optim, dist,
discount_factor=args.gamma, gae_lambda=args.gae_lambda,
vf_coef=args.vf_coef, ent_coef=args.ent_coef,
max_grad_norm=args.max_grad_norm, reward_normalization=args.rew_norm,
action_space=env.action_space)
9 changes: 7 additions & 2 deletions test/discrete/test_pg.py
@@ -17,7 +17,7 @@
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='CartPole-v0')
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--seed', type=int, default=1)
parser.add_argument('--buffer-size', type=int, default=20000)
parser.add_argument('--lr', type=float, default=1e-3)
parser.add_argument('--gamma', type=float, default=0.95)
@@ -27,7 +27,7 @@ def get_args():
parser.add_argument('--repeat-per-collect', type=int, default=2)
parser.add_argument('--batch-size', type=int, default=64)
parser.add_argument('--hidden-sizes', type=int,
nargs='*', default=[128, 128, 128, 128])
nargs='*', default=[64, 64])
parser.add_argument('--training-num', type=int, default=8)
parser.add_argument('--test-num', type=int, default=100)
parser.add_argument('--logdir', type=str, default='log')
@@ -65,6 +65,11 @@ def test_pg(args=get_args()):
policy = PGPolicy(net, optim, dist, args.gamma,
reward_normalization=args.rew_norm,
action_space=env.action_space)
for m in net.modules():
if isinstance(m, torch.nn.Linear):
# orthogonal initialization
torch.nn.init.orthogonal_(m.weight, gain=np.sqrt(2))
torch.nn.init.zeros_(m.bias)
# collector
train_collector = Collector(
policy, train_envs,
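The added lines above apply orthogonal initialization to every linear layer of the policy network before training. The same pattern in isolation, with a stand-in network rather than the test's actual Net (a generic sketch):

import numpy as np
import torch

net = torch.nn.Sequential(  # stand-in for the policy network
    torch.nn.Linear(4, 64), torch.nn.Tanh(),
    torch.nn.Linear(64, 64), torch.nn.Tanh(),
    torch.nn.Linear(64, 2),
)
for m in net.modules():
    if isinstance(m, torch.nn.Linear):
        # orthogonal weights with gain sqrt(2) and zero bias, a common
        # initialization choice for on-policy actor-critic networks
        torch.nn.init.orthogonal_(m.weight, gain=np.sqrt(2))
        torch.nn.init.zeros_(m.bias)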
3 changes: 2 additions & 1 deletion test/discrete/test_ppo.py
@@ -80,7 +80,8 @@ def test_ppo(args=get_args()):
actor.parameters()).union(critic.parameters()), lr=args.lr)
dist = torch.distributions.Categorical
policy = PPOPolicy(
actor, critic, optim, dist, args.gamma,
actor, critic, optim, dist,
discount_factor=args.gamma,
max_grad_norm=args.max_grad_norm,
eps_clip=args.eps_clip,
vf_coef=args.vf_coef,
37 changes: 10 additions & 27 deletions tianshou/policy/base.py
@@ -4,7 +4,7 @@
from torch import nn
from numba import njit
from abc import ABC, abstractmethod
from typing import Any, Dict, Union, Optional, Callable
from typing import Any, Dict, Tuple, Union, Optional, Callable

from tianshou.data import Batch, ReplayBuffer, to_torch_as, to_numpy

@@ -254,14 +254,14 @@ def compute_episodic_return(
buffer: ReplayBuffer,
indice: np.ndarray,
v_s_: Optional[Union[np.ndarray, torch.Tensor]] = None,
v_s: Optional[Union[np.ndarray, torch.Tensor]] = None,
gamma: float = 0.99,
gae_lambda: float = 0.95,
rew_norm: bool = False,
) -> Batch:
) -> Tuple[np.ndarray, np.ndarray]:
"""Compute returns over given batch.
Use Implementation of Generalized Advantage Estimator (arXiv:1506.02438)
to calculate q function/reward to go of given batch.
to calculate q/advantage value of given batch.
:param Batch batch: a data batch which contains several episodes of data in
sequential order. Mind that the end of each finished episode of batch
@@ -273,25 +273,23 @@ def compute_episodic_return(
:param float gamma: the discount factor, should be in [0, 1]. Default to 0.99.
:param float gae_lambda: the parameter for Generalized Advantage Estimation,
should be in [0, 1]. Default to 0.95.
:param bool rew_norm: normalize the reward to Normal(0, 1). Default to False.
:return: a Batch. The result will be stored in batch.returns as a numpy
array with shape (bsz, ).
:return: two numpy arrays (returns, advantage) with each shape (bsz, ).
"""
rew = batch.rew
if v_s_ is None:
assert np.isclose(gae_lambda, 1.0)
v_s_ = np.zeros_like(rew)
else:
v_s_ = to_numpy(v_s_.flatten()) * BasePolicy.value_mask(buffer, indice)
v_s = np.roll(v_s_, 1) if v_s is None else to_numpy(v_s.flatten())

end_flag = batch.done.copy()
end_flag[np.isin(indice, buffer.unfinished_index())] = True
returns = _episodic_return(v_s_, rew, end_flag, gamma, gae_lambda)
if rew_norm and not np.isclose(returns.std(), 0.0, 1e-2):
returns = (returns - returns.mean()) / returns.std()
batch.returns = returns
return batch
advantage = _gae_return(v_s, v_s_, rew, end_flag, gamma, gae_lambda)
returns = advantage + v_s
# normalization varies from each policy, so we don't do it here
return returns, advantage

@staticmethod
def compute_nstep_return(
@@ -355,8 +353,6 @@ def _compile(self) -> None:
i64 = np.array([[0, 1]], dtype=np.int64)
_gae_return(f64, f64, f64, b, 0.1, 0.1)
_gae_return(f32, f32, f64, b, 0.1, 0.1)
_episodic_return(f64, f64, b, 0.1, 0.1)
_episodic_return(f32, f64, b, 0.1, 0.1)
_nstep_return(f64, b, f32.reshape(-1, 1), i64, 0.1, 1)


Expand All @@ -379,19 +375,6 @@ def _gae_return(
return returns


@njit
def _episodic_return(
v_s_: np.ndarray,
rew: np.ndarray,
end_flag: np.ndarray,
gamma: float,
gae_lambda: float,
) -> np.ndarray:
"""Numba speedup: 4.1s -> 0.057s."""
v_s = np.roll(v_s_, 1)
return _gae_return(v_s, v_s_, rew, end_flag, gamma, gae_lambda) + v_s


@njit
def _nstep_return(
rew: np.ndarray,
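compute_episodic_return is now built entirely on _gae_return, whose body is collapsed above. For reference, the standard GAE recursion it corresponds to, as a plain-NumPy sketch (illustrative names, not the numba kernel verbatim):

import numpy as np

def gae_advantage(v_s, v_s_, rew, end_flag, gamma=0.99, gae_lambda=0.95):
    # v_s_ is assumed to be zeroed already wherever bootstrapping is invalid
    # (compute_episodic_return does this via BasePolicy.value_mask)
    delta = rew + gamma * v_s_ - v_s              # one-step TD residual
    discount = (1.0 - end_flag) * gamma * gae_lambda
    adv = np.zeros_like(rew, dtype=float)
    gae = 0.0
    for t in range(len(rew) - 1, -1, -1):         # accumulate backwards in time
        gae = delta[t] + discount[t] * gae
        adv[t] = gae
    return adv

# compute_episodic_return then reports returns = advantage + v_s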
59 changes: 16 additions & 43 deletions tianshou/policy/modelfree/a2c.py
@@ -2,7 +2,7 @@
import numpy as np
from torch import nn
import torch.nn.functional as F
from typing import Any, Dict, List, Type, Union, Optional
from typing import Any, Dict, List, Type, Optional

from tianshou.policy import PGPolicy
from tianshou.data import Batch, ReplayBuffer, to_torch_as, to_numpy
@@ -53,69 +53,42 @@ def __init__(
critic: torch.nn.Module,
optim: torch.optim.Optimizer,
dist_fn: Type[torch.distributions.Distribution],
discount_factor: float = 0.99,
vf_coef: float = 0.5,
ent_coef: float = 0.01,
max_grad_norm: Optional[float] = None,
gae_lambda: float = 0.95,
reward_normalization: bool = False,
max_batchsize: int = 256,
**kwargs: Any
) -> None:
super().__init__(None, optim, dist_fn, discount_factor, **kwargs)
self.actor = actor
super().__init__(actor, optim, dist_fn, **kwargs)
self.critic = critic
assert 0.0 <= gae_lambda <= 1.0, "GAE lambda should be in [0, 1]."
self._lambda = gae_lambda
self._weight_vf = vf_coef
self._weight_ent = ent_coef
self._grad_norm = max_grad_norm
self._batch = max_batchsize
self._rew_norm = reward_normalization

def process_fn(
self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray
) -> Batch:
if self._lambda in [0.0, 1.0]:
return self.compute_episodic_return(
batch, buffer, indice,
None, gamma=self._gamma, gae_lambda=self._lambda)
v_ = []
v_s_ = []
with torch.no_grad():
for b in batch.split(self._batch, shuffle=False, merge_last=True):
v_.append(to_numpy(self.critic(b.obs_next)))
v_ = np.concatenate(v_, axis=0)
return self.compute_episodic_return(
batch, buffer, indice, v_,
gamma=self._gamma, gae_lambda=self._lambda, rew_norm=self._rew_norm)

def forward(
self,
batch: Batch,
state: Optional[Union[dict, Batch, np.ndarray]] = None,
**kwargs: Any
) -> Batch:
"""Compute action over the given batch data.
:return: A :class:`~tianshou.data.Batch` which has 4 keys:
* ``act`` the action.
* ``logits`` the network's raw output.
* ``dist`` the action distribution.
* ``state`` the hidden state.
.. seealso::
Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for
more detailed explanation.
"""
logits, h = self.actor(batch.obs, state=state, info=batch.info)
if isinstance(logits, tuple):
dist = self.dist_fn(*logits)
v_s_.append(to_numpy(self.critic(b.obs_next)))
v_s_ = np.concatenate(v_s_, axis=0)
if self._rew_norm: # unnormalize v_s_
v_s_ = v_s_ * np.sqrt(self.ret_rms.var + self._eps) + self.ret_rms.mean
unnormalized_returns, _ = self.compute_episodic_return(
batch, buffer, indice, v_s_=v_s_,
gamma=self._gamma, gae_lambda=self._lambda)
if self._rew_norm:
batch.returns = (unnormalized_returns - self.ret_rms.mean) / \
np.sqrt(self.ret_rms.var + self._eps)
self.ret_rms.update(unnormalized_returns)
else:
dist = self.dist_fn(logits)
act = dist.sample()
return Batch(logits=logits, act=act, state=h, dist=dist)
batch.returns = unnormalized_returns
return batch

def learn( # type: ignore
self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any
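The new process_fn above is interleaved with the removed forward() method. Read on its own, the new flow is: un-scale the critic's prediction if return normalization is enabled, run GAE on the raw scale, then normalize the critic targets and update the running statistics. A self-contained sketch, where RunningMeanStd is a minimal stand-in for whatever the policy keeps as self.ret_rms (not shown in this diff):

import numpy as np

class RunningMeanStd:
    # minimal running mean/variance tracker (illustrative stand-in)
    def __init__(self):
        self.mean, self.var, self.count = 0.0, 1.0, 1e-4

    def update(self, x):
        m, v, n = x.mean(), x.var(), len(x)
        delta, total = m - self.mean, self.count + n
        self.var = (self.var * self.count + v * n
                    + delta ** 2 * self.count * n / total) / total
        self.mean += delta * n / total
        self.count = total

ret_rms, eps = RunningMeanStd(), 1e-8

# the critic is trained on normalized targets, so un-scale its output first
v_s_ = np.array([0.1, -0.2, 0.3])                  # toy normalized critic output
v_s_raw = v_s_ * np.sqrt(ret_rms.var + eps) + ret_rms.mean

# ... feed v_s_raw into compute_episodic_return to obtain unnormalized_returns ...
unnormalized_returns = np.array([1.0, 0.5, 2.0])   # toy placeholder values

# normalize the new targets with the old statistics, then update the statistics
returns = (unnormalized_returns - ret_rms.mean) / np.sqrt(ret_rms.var + eps)
ret_rms.update(unnormalized_returns)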
(remaining file diffs in this commit not loaded)
