Skip to content

Commit

Permalink
fix nstep rew reshape
Browse files — browse the repository at this point in the history
  • Loading branch information
Trinkle23897 committed Jan 4, 2021
1 parent 3695f12 commit 22fa78a
Showing 1 changed file with 2 additions and 3 deletions.
5 changes: 2 additions & 3 deletions tianshou/policy/base.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -312,16 +312,15 @@ def _nstep_return(
"""Numba speedup: 0.3s -> 0.15s."""
target_shape = target_q.shape
bsz = target_shape[0]
# change rew/target_q to 2d array
# change target_q to 2d array
target_q = target_q.reshape(bsz, -1)
rew = rew.reshape(-1, 1) # assume reward is a scalar
returns = np.zeros(target_q.shape)
gammas = np.full(indice.shape, n_step)
for n in range(n_step - 1, -1, -1):
now = (indice + n) % buf_len
gammas[done[now] > 0] = n
returns[done[now] > 0] = 0.0
returns = (rew[now] - mean) / std + gamma * returns
returns = (rew[now].reshape(-1, 1) - mean) / std + gamma * returns
target_q[gammas != n_step] = 0.0
gammas = gammas.reshape(-1, 1)
target_q = target_q * (gamma ** gammas) + returns
Expand Down

0 comments on commit 22fa78a

Please sign in to comment.