Refactor PG algorithm and change behavior of compute_episodic_return #319

Merged: 12 commits, Mar 23, 2021
24 changes: 12 additions & 12 deletions test/base/test_returns.py
@@ -30,9 +30,9 @@ def test_episodic_returns(size=2560):
for b in batch:
b.obs = b.act = 1
buf.add(b)
batch = fn(batch, buf, buf.sample_index(0), None, gamma=.1, gae_lambda=1)
returns, _ = fn(batch, buf, buf.sample_index(0), gamma=.1, gae_lambda=1)
ans = np.array([0, 1.23, 2.3, 3, 4.5, 5, 6.7, 7])
assert np.allclose(batch.returns, ans)
assert np.allclose(returns, ans)
buf.reset()
batch = Batch(
done=np.array([0, 1, 0, 1, 0, 1, 0.]),
@@ -41,9 +41,9 @@ def test_episodic_returns(size=2560):
for b in batch:
b.obs = b.act = 1
buf.add(b)
batch = fn(batch, buf, buf.sample_index(0), None, gamma=.1, gae_lambda=1)
returns, _ = fn(batch, buf, buf.sample_index(0), gamma=.1, gae_lambda=1)
ans = np.array([7.6, 6, 1.2, 2, 3.4, 4, 5])
assert np.allclose(batch.returns, ans)
assert np.allclose(returns, ans)
buf.reset()
batch = Batch(
done=np.array([0, 1, 0, 1, 0, 0, 1.]),
@@ -52,9 +52,9 @@ def test_episodic_returns(size=2560):
for b in batch:
b.obs = b.act = 1
buf.add(b)
batch = fn(batch, buf, buf.sample_index(0), None, gamma=.1, gae_lambda=1)
returns, _ = fn(batch, buf, buf.sample_index(0), gamma=.1, gae_lambda=1)
ans = np.array([7.6, 6, 1.2, 2, 3.45, 4.5, 5])
assert np.allclose(batch.returns, ans)
assert np.allclose(returns, ans)
buf.reset()
batch = Batch(
done=np.array([0, 0, 0, 1., 0, 0, 0, 1, 0, 0, 0, 1]),
@@ -64,12 +64,12 @@ def test_episodic_returns(size=2560):
b.obs = b.act = 1
buf.add(b)
v = np.array([2., 3., 4, -1, 5., 6., 7, -2, 8., 9., 10, -3])
ret = fn(batch, buf, buf.sample_index(0), v, gamma=0.99, gae_lambda=0.95)
returns = np.array([
returns, _ = fn(batch, buf, buf.sample_index(0), v, gamma=0.99, gae_lambda=0.95)
ground_truth = np.array([
454.8344, 376.1143, 291.298, 200.,
464.5610, 383.1085, 295.387, 201.,
474.2876, 390.1027, 299.476, 202.])
assert np.allclose(ret.returns, returns)
assert np.allclose(returns, ground_truth)
buf.reset()
batch = Batch(
done=np.array([0, 0, 0, 1., 0, 0, 0, 1, 0, 0, 0, 1]),
@@ -82,12 +82,12 @@ def test_episodic_returns(size=2560):
b.obs = b.act = 1
buf.add(b)
v = np.array([2., 3., 4, -1, 5., 6., 7, -2, 8., 9., 10, -3])
ret = fn(batch, buf, buf.sample_index(0), v, gamma=0.99, gae_lambda=0.95)
returns = np.array([
returns, _ = fn(batch, buf, buf.sample_index(0), v, gamma=0.99, gae_lambda=0.95)
ground_truth = np.array([
454.0109, 375.2386, 290.3669, 199.01,
462.9138, 381.3571, 293.5248, 199.02,
474.2876, 390.1027, 299.476, 202.])
assert np.allclose(ret.returns, returns)
assert np.allclose(returns, ground_truth)

if __name__ == '__main__':
buf = ReplayBuffer(size)
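
The test changes above track the central interface change of this PR: compute_episodic_return no longer writes batch.returns in place but returns a (returns, advantage) pair. A minimal sketch of the new call pattern, mirroring the structure of the tests (the done/rew values below are illustrative placeholders, not the test's data):

```python
import numpy as np

from tianshou.data import Batch, ReplayBuffer
from tianshou.policy import BasePolicy

buf = ReplayBuffer(size=8)
batch = Batch(
    done=np.array([0, 0, 1, 1, 0, 1, 0, 1.]),
    rew=np.array([1., 2., 3., 4., 5., 6., 7., 8.]),
)
for b in batch:
    b.obs = b.act = 1  # ReplayBuffer.add also needs obs/act fields
    buf.add(b)

# old API: fn(...) returned the batch with batch.returns filled in
# new API: a (returns, advantage) tuple; the batch is left untouched
returns, advantage = BasePolicy.compute_episodic_return(
    batch, buf, buf.sample_index(0), gamma=0.1, gae_lambda=1.0)
print(returns.shape, advantage.shape)  # both (8,)
```
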
3 changes: 2 additions & 1 deletion test/continuous/test_ppo.py
@@ -91,7 +91,8 @@ def test_ppo(args=get_args()):
def dist(*logits):
return Independent(Normal(*logits), 1)
policy = PPOPolicy(
actor, critic, optim, dist, args.gamma,
actor, critic, optim, dist,
discount_factor=args.gamma,
max_grad_norm=args.max_grad_norm,
eps_clip=args.eps_clip,
vf_coef=args.vf_coef,
3 changes: 2 additions & 1 deletion test/discrete/test_a2c_with_il.py
@@ -78,7 +78,8 @@ def test_a2c_with_il(args=get_args()):
actor.parameters()).union(critic.parameters()), lr=args.lr)
dist = torch.distributions.Categorical
policy = A2CPolicy(
actor, critic, optim, dist, args.gamma, gae_lambda=args.gae_lambda,
actor, critic, optim, dist,
discount_factor=args.gamma, gae_lambda=args.gae_lambda,
vf_coef=args.vf_coef, ent_coef=args.ent_coef,
max_grad_norm=args.max_grad_norm, reward_normalization=args.rew_norm,
action_space=env.action_space)
3 changes: 2 additions & 1 deletion test/discrete/test_ppo.py
@@ -80,7 +80,8 @@ def test_ppo(args=get_args()):
actor.parameters()).union(critic.parameters()), lr=args.lr)
dist = torch.distributions.Categorical
policy = PPOPolicy(
actor, critic, optim, dist, args.gamma,
actor, critic, optim, dist,
discount_factor=args.gamma,
max_grad_norm=args.max_grad_norm,
eps_clip=args.eps_clip,
vf_coef=args.vf_coef,
33 changes: 10 additions & 23 deletions tianshou/policy/base.py
@@ -254,14 +254,14 @@ def compute_episodic_return(
buffer: ReplayBuffer,
indice: np.ndarray,
v_s_: Optional[Union[np.ndarray, torch.Tensor]] = None,
v_s: Optional[Union[np.ndarray, torch.Tensor]] = None,
gamma: float = 0.99,
gae_lambda: float = 0.95,
rew_norm: bool = False,
) -> Batch:
"""Compute returns over given batch.

Use Implementation of Generalized Advantage Estimator (arXiv:1506.02438)
to calculate q function/reward to go of given batch.
to calculate Q&A function/reward to go of given batch.

:param Batch batch: a data batch which contains several episodes of data in
sequential order. Mind that the end of each finished episode of batch
@@ -273,8 +273,8 @@ def compute_episodic_return(
:param float gamma: the discount factor, should be in [0, 1]. Default to 0.99.
:param float gae_lambda: the parameter for Generalized Advantage Estimation,
should be in [0, 1]. Default to 0.95.
:param bool rew_norm: normalize the reward to Normal(0, 1). Default to False.

:return: two numpy arrays (returns, advantage) with shape (bsz, ).
"""
@@ -284,14 +284,16 @@ def compute_episodic_return(
v_s_ = np.zeros_like(rew)
else:
v_s_ = to_numpy(v_s_.flatten()) * BasePolicy.value_mask(buffer, indice)
if v_s is None:
v_s = np.roll(v_s_, 1)
else:
v_s = to_numpy(v_s.flatten())

end_flag = batch.done.copy()
end_flag[np.isin(indice, buffer.unfinished_index())] = True
returns = _episodic_return(v_s_, rew, end_flag, gamma, gae_lambda)
if rew_norm and not np.isclose(returns.std(), 0.0, 1e-2):
returns = (returns - returns.mean()) / returns.std()
batch.returns = returns
return batch
advantage = _gae_return(v_s, v_s_, rew, end_flag, gamma, gae_lambda)
returns = advantage + v_s
return returns, advantage

@staticmethod
def compute_nstep_return(
@@ -355,8 +357,6 @@ def _compile(self) -> None:
i64 = np.array([[0, 1]], dtype=np.int64)
_gae_return(f64, f64, f64, b, 0.1, 0.1)
_gae_return(f32, f32, f64, b, 0.1, 0.1)
_episodic_return(f64, f64, b, 0.1, 0.1)
_episodic_return(f32, f64, b, 0.1, 0.1)
_nstep_return(f64, b, f32.reshape(-1, 1), i64, 0.1, 1)


@@ -379,19 +379,6 @@ def _gae_return(
return returns


@njit
def _episodic_return(
v_s_: np.ndarray,
rew: np.ndarray,
end_flag: np.ndarray,
gamma: float,
gae_lambda: float,
) -> np.ndarray:
"""Numba speedup: 4.1s -> 0.057s."""
v_s = np.roll(v_s_, 1)
return _gae_return(v_s, v_s_, rew, end_flag, gamma, gae_lambda) + v_s


@njit
def _nstep_return(
rew: np.ndarray,
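
With the wrapper _episodic_return gone (it was only np.roll plus _gae_return), compute_episodic_return now calls _gae_return directly and hands back both the advantage and returns = advantage + v_s. For reference, a plain-numpy sketch of that recursion (illustrative only; it assumes v_s_ has already been zero-masked at true episode ends, as value_mask does above, and the library's _gae_return is a numba-compiled equivalent):

```python
import numpy as np

def gae_sketch(v_s: np.ndarray, v_s_: np.ndarray, rew: np.ndarray,
               end_flag: np.ndarray, gamma: float, gae_lambda: float) -> np.ndarray:
    """Generalized Advantage Estimation: A_t = delta_t + gamma * lambda * A_{t+1}."""
    # TD residual: delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
    delta = rew + gamma * v_s_ - v_s
    # cut the recursion wherever an episode ends (or is truncated)
    discount = (1.0 - end_flag) * gamma * gae_lambda
    advantage = np.zeros_like(rew)
    gae = 0.0
    for t in range(len(rew) - 1, -1, -1):
        gae = delta[t] + discount[t] * gae
        advantage[t] = gae
    return advantage

# compute_episodic_return then reconstructs the training targets as
#   returns = advantage + v_s
```
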
59 changes: 16 additions & 43 deletions tianshou/policy/modelfree/a2c.py
@@ -2,7 +2,7 @@
import numpy as np
from torch import nn
import torch.nn.functional as F
from typing import Any, Dict, List, Type, Union, Optional
from typing import Any, Dict, List, Type, Optional

from tianshou.policy import PGPolicy
from tianshou.data import Batch, ReplayBuffer, to_torch_as, to_numpy
@@ -53,69 +53,42 @@ def __init__(
critic: torch.nn.Module,
optim: torch.optim.Optimizer,
dist_fn: Type[torch.distributions.Distribution],
discount_factor: float = 0.99,
vf_coef: float = 0.5,
ent_coef: float = 0.01,
max_grad_norm: Optional[float] = None,
gae_lambda: float = 0.95,
reward_normalization: bool = False,
max_batchsize: int = 256,
**kwargs: Any
) -> None:
super().__init__(None, optim, dist_fn, discount_factor, **kwargs)
self.actor = actor
super().__init__(actor, optim, dist_fn, **kwargs)
self.critic = critic
assert 0.0 <= gae_lambda <= 1.0, "GAE lambda should be in [0, 1]."
self._lambda = gae_lambda
self._weight_vf = vf_coef
self._weight_ent = ent_coef
self._grad_norm = max_grad_norm
self._batch = max_batchsize
self._rew_norm = reward_normalization

def process_fn(
self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray
) -> Batch:
if self._lambda in [0.0, 1.0]:
return self.compute_episodic_return(
batch, buffer, indice,
None, gamma=self._gamma, gae_lambda=self._lambda)
v_ = []
v_s_ = []
with torch.no_grad():
for b in batch.split(self._batch, shuffle=False, merge_last=True):
v_.append(to_numpy(self.critic(b.obs_next)))
v_ = np.concatenate(v_, axis=0)
return self.compute_episodic_return(
batch, buffer, indice, v_,
gamma=self._gamma, gae_lambda=self._lambda, rew_norm=self._rew_norm)

def forward(
self,
batch: Batch,
state: Optional[Union[dict, Batch, np.ndarray]] = None,
**kwargs: Any
) -> Batch:
"""Compute action over the given batch data.

:return: A :class:`~tianshou.data.Batch` which has 4 keys:

* ``act`` the action.
* ``logits`` the network's raw output.
* ``dist`` the action distribution.
* ``state`` the hidden state.

.. seealso::

Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for
more detailed explanation.
"""
logits, h = self.actor(batch.obs, state=state, info=batch.info)
if isinstance(logits, tuple):
dist = self.dist_fn(*logits)
v_s_.append(to_numpy(self.critic(b.obs_next)))
v_s_ = np.concatenate(v_s_, axis=0)
if self._rew_norm:
# unnormalize v_s_
v_s_ = v_s_ * np.sqrt(self.ret_rms.var + self._eps) + self.ret_rms.mean
un_normalized_returns, _ = self.compute_episodic_return(
batch, buffer, indice, v_s_, gamma=self._gamma, gae_lambda=self._lambda)
if self._rew_norm:
batch.returns = (un_normalized_returns - self.ret_rms.mean) / \
np.sqrt(self.ret_rms.var + self._eps)
self.ret_rms.update(un_normalized_returns)
else:
dist = self.dist_fn(logits)
act = dist.sample()
return Batch(logits=logits, act=act, state=h, dist=dist)
batch.returns = un_normalized_returns
return batch

def learn( # type: ignore
self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any
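
Because the removed lines of the old forward method are interleaved with the added lines of process_fn in the view above, the new control flow is easier to read assembled in one place. The following is a paraphrase of the added A2C process_fn (a sketch, not the verbatim merged code; the attributes it reads, namely critic, _batch, _rew_norm, ret_rms, _eps, _gamma, and _lambda, are the ones set in __init__ above and in pg.py below):

```python
import numpy as np
import torch

from tianshou.data import Batch, ReplayBuffer, to_numpy

def a2c_process_fn_sketch(policy, batch: Batch, buffer: ReplayBuffer,
                          indice: np.ndarray) -> Batch:
    # bootstrap values V(s_{t+1}) from the critic, computed in mini-batches
    v_s_ = []
    with torch.no_grad():
        for b in batch.split(policy._batch, shuffle=False, merge_last=True):
            v_s_.append(to_numpy(policy.critic(b.obs_next)))
    v_s_ = np.concatenate(v_s_, axis=0)
    if policy._rew_norm:
        # the critic is trained on normalized returns, so map its outputs
        # back to the raw return scale before computing GAE targets
        v_s_ = v_s_ * np.sqrt(policy.ret_rms.var + policy._eps) + policy.ret_rms.mean
    un_normalized_returns, _ = policy.compute_episodic_return(
        batch, buffer, indice, v_s_,
        gamma=policy._gamma, gae_lambda=policy._lambda)
    if policy._rew_norm:
        # train on normalized returns, then refresh the running statistics
        batch.returns = (un_normalized_returns - policy.ret_rms.mean) / \
            np.sqrt(policy.ret_rms.var + policy._eps)
        policy.ret_rms.update(un_normalized_returns)
    else:
        batch.returns = un_normalized_returns
    return batch
```
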
56 changes: 20 additions & 36 deletions tianshou/policy/modelfree/pg.py
@@ -4,10 +4,11 @@

from tianshou.policy import BasePolicy
from tianshou.data import Batch, ReplayBuffer, to_torch_as
from tianshou.utils import RunningMeanStd


class PGPolicy(BasePolicy):
"""Implementation of Vanilla Policy Gradient.
"""Implementation of REINFORCE algorithm.

:param torch.nn.Module model: a model following the rules in
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
@@ -45,14 +46,15 @@ def __init__(
) -> None:
super().__init__(action_scaling=action_scaling,
action_bound_method=action_bound_method, **kwargs)
if model is not None:
self.model: torch.nn.Module = model
self.actor = model
self.optim = optim
self.lr_scheduler = lr_scheduler
self.dist_fn = dist_fn
assert 0.0 <= discount_factor <= 1.0, "discount factor should be in [0, 1]"
self._gamma = discount_factor
self._rew_norm = reward_normalization
self.ret_rms = RunningMeanStd()
self._eps = 1e-8

def process_fn(
self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray
@@ -65,11 +67,16 @@ def process_fn(
where :math:`T` is the terminal time step, :math:`\gamma` is the
discount factor, :math:`\gamma \in [0, 1]`.
"""
# batch.returns = self._vanilla_returns(batch)
# batch.returns = self._vectorized_returns(batch)
return self.compute_episodic_return(
batch, buffer, indice, gamma=self._gamma,
gae_lambda=1.0, rew_norm=self._rew_norm)
v_s_ = np.full(indice.shape, self.ret_rms.mean)
un_normalized_returns, _ = self.compute_episodic_return(
batch, buffer, indice, v_s_, gamma=self._gamma, gae_lambda=1.0)
if self._rew_norm:
batch.returns = (un_normalized_returns - self.ret_rms.mean) / \
np.sqrt(self.ret_rms.var + self._eps)
self.ret_rms.update(un_normalized_returns)
else:
batch.returns = un_normalized_returns
return batch

def forward(
self,
@@ -91,7 +98,7 @@ def forward(
Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for
more detailed explanation.
"""
logits, h = self.model(batch.obs, state=state, info=batch.info)
logits, h = self.actor(batch.obs, state=state)
if isinstance(logits, tuple):
dist = self.dist_fn(*logits)
else:
@@ -106,9 +113,10 @@ def learn( # type: ignore
for _ in range(repeat):
for b in batch.split(batch_size, merge_last=True):
self.optim.zero_grad()
dist = self(b).dist
a = to_torch_as(b.act, dist.logits)
r = to_torch_as(b.returns, dist.logits)
result = self(b)
dist = result.dist
a = to_torch_as(b.act, result.act)
r = to_torch_as(b.returns, result.act)
log_prob = dist.log_prob(a).reshape(len(r), -1).transpose(0, 1)
loss = -(log_prob * r).mean()
loss.backward()
@@ -119,27 +127,3 @@ def learn( # type: ignore
self.lr_scheduler.step()

return {"loss": losses}

# def _vanilla_returns(self, batch):
# returns = batch.rew[:]
# last = 0
# for i in range(len(returns) - 1, -1, -1):
# if not batch.done[i]:
# returns[i] += self._gamma * last
# last = returns[i]
# return returns

# def _vectorized_returns(self, batch):
# # according to my tests, it is slower than _vanilla_returns
# # import scipy.signal
# convolve = np.convolve
# # convolve = scipy.signal.convolve
# rew = batch.rew[::-1]
# batch_size = len(rew)
# gammas = self._gamma ** np.arange(batch_size)
# c = convolve(rew, gammas)[:batch_size]
# T = np.where(batch.done[::-1])[0]
# d = np.zeros_like(rew)
# d[T] += c[T] - rew[T]
# d[T[1:]] -= d[T[:-1]] * self._gamma ** np.diff(T)
# return (c - convolve(d, gammas)[:batch_size])[::-1]
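
The docstring above keeps the REINFORCE target G_t = sum_i gamma^i r_{t+i}, which process_fn now obtains through compute_episodic_return with gae_lambda=1.0. With lambda = 1 and terminal values masked out, GAE collapses to the plain discounted return, and the v_s_ array filled with ret_rms.mean only matters as a bootstrap value for episodes that are cut off at the end of the buffer. A small self-contained check of that reduction (numpy only, illustrative):

```python
import numpy as np

def discounted_returns(rew: np.ndarray, done: np.ndarray, gamma: float) -> np.ndarray:
    """G_t = r_t + gamma * G_{t+1}, restarted at every episode end."""
    ret = np.zeros_like(rew, dtype=float)
    running = 0.0
    for t in range(len(rew) - 1, -1, -1):
        running = rew[t] + gamma * (1.0 - done[t]) * running
        ret[t] = running
    return ret

rew = np.array([1.0, 1.0, 1.0, 1.0])
done = np.array([0.0, 0.0, 0.0, 1.0])
print(discounted_returns(rew, done, 0.99))  # -> 3.940399, 2.9701, 1.99, 1.0
```
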