diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py
index b875f6075..6e172dfd7 100644
--- a/tianshou/policy/base.py
+++ b/tianshou/policy/base.py
@@ -312,16 +312,15 @@ def _nstep_return(
     """Numba speedup: 0.3s -> 0.15s."""
     target_shape = target_q.shape
     bsz = target_shape[0]
-    # change rew/target_q to 2d array
+    # change target_q to 2d array
     target_q = target_q.reshape(bsz, -1)
-    rew = rew.reshape(-1, 1)  # assume reward is a scalar
     returns = np.zeros(target_q.shape)
     gammas = np.full(indice.shape, n_step)
     for n in range(n_step - 1, -1, -1):
         now = (indice + n) % buf_len
         gammas[done[now] > 0] = n
         returns[done[now] > 0] = 0.0
-        returns = (rew[now] - mean) / std + gamma * returns
+        returns = (rew[now].reshape(-1, 1) - mean) / std + gamma * returns
         target_q[gammas != n_step] = 0.0
     gammas = gammas.reshape(-1, 1)
     target_q = target_q * (gamma ** gammas) + returns
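
For context: the hunk stops reshaping the entire `rew` buffer to a column vector up front (which baked in the scalar-reward assumption for the whole buffer) and instead reshapes only the gathered batch `rew[now]` inside the backward loop. Below is a minimal standalone NumPy sketch of the patched recursion, with the Numba JIT decorator omitted; the function name and the toy inputs are hypothetical, while the argument names and the loop body follow the hunk.

```python
import numpy as np


def nstep_return_sketch(rew, done, target_q, indice, gamma, n_step, buf_len,
                        mean=0.0, std=1.0):
    """Sketch of the patched n-step return recursion.

    rew and done are 1-D buffers of length buf_len; target_q has shape
    (batch, num_q); indice holds each transition's start index in the buffer;
    mean/std normalize the rewards.
    """
    bsz = target_q.shape[0]
    target_q = target_q.reshape(bsz, -1)
    returns = np.zeros(target_q.shape)
    gammas = np.full(indice.shape, n_step)
    for n in range(n_step - 1, -1, -1):
        now = (indice + n) % buf_len
        # where an episode ends at step n, the effective horizon is n
        gammas[done[now] > 0] = n
        # discard rewards accumulated from beyond the episode boundary
        returns[done[now] > 0] = 0.0
        # reshape the gathered rewards to a column so they broadcast over num_q
        returns = (rew[now].reshape(-1, 1) - mean) / std + gamma * returns
        # zero the bootstrap term where the trajectory ended within the horizon
        target_q[gammas != n_step] = 0.0
    gammas = gammas.reshape(-1, 1)
    return target_q * (gamma ** gammas) + returns


# hypothetical toy inputs: buffer of 8 steps, batch of 2, a single Q-value,
# and an episode that terminates at buffer index 5
rew = np.arange(8, dtype=np.float64)
done = np.zeros(8)
done[5] = 1.0
target_q = np.ones((2, 1))
indice = np.array([2, 4])
print(nstep_return_sketch(rew, done, target_q, indice,
                          gamma=0.99, n_step=3, buf_len=8))
```

Indexing first and reshaping the (batch,)-sized slice, rather than reshaping the full (buf_len,)-sized buffer, keeps the per-step reshape local to the batch actually being processed; the result is numerically identical for scalar per-step rewards.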