
Commit

Fix reward bug in the CleanRL example (#53)
vwxyzjn authored Jan 13, 2022
1 parent 057f6b8 commit 50e8689
Showing 1 changed file with 95 additions and 74 deletions.
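For orientation, the core of the fix sits in the RecordEpisodeStatistics wrapper in the diff below: episode returns are now accumulated from envpool's info["reward"] field instead of the per-step rewards (which get clipped once reward_clip=True is passed), and the running statistics are only reset once info["lives"] reports that every life is gone. A condensed sketch of that bookkeeping, assuming numpy arrays of shape (num_envs,); the standalone-function form and its name are illustrative, not part of the diff:

import numpy as np

def update_episode_stats(episode_returns, episode_lengths, dones, infos, has_lives):
    # Accumulate the reward reported by envpool in infos["reward"] rather than
    # the (possibly clipped) per-step reward returned by env.step().
    episode_returns += infos["reward"]
    episode_lengths += 1
    # Snapshot the running totals before any reset, mirroring the wrapper's
    # returned_episode_returns / returned_episode_lengths buffers.
    returned_returns = episode_returns.copy()
    returned_lengths = episode_lengths.copy()
    if has_lives:
        # With episodic_life=True a done fires on every life loss, so the
        # statistics are only cleared once all lives are exhausted.
        all_lives_exhausted = infos["lives"] == 0
        episode_returns *= 1 - all_lives_exhausted
        episode_lengths *= 1 - all_lives_exhausted
    else:
        episode_returns *= 1 - dones
        episode_lengths *= 1 - dones
    return returned_returns, returned_lengths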
169 changes: 95 additions & 74 deletions examples/cleanrl_examples/ppo_atari_envpool.py
@@ -23,6 +23,7 @@
import os
import random
import time
+ from collections import deque
from distutils.util import strtobool

import gym
@@ -40,171 +41,162 @@ def parse_args():
# fmt: off
parser = argparse.ArgumentParser()
parser.add_argument(
- '--exp-name',
+ "--exp-name",
type=str,
default=os.path.basename(__file__).rstrip(".py"),
- help='the name of this experiment'
+ help="the name of this experiment"
)
parser.add_argument(
- '--gym-id',
+ "--gym-id",
type=str,
default="Pong-v5",
- help='the id of the gym environment'
+ help="the id of the gym environment"
)
parser.add_argument(
- '--learning-rate',
+ "--learning-rate",
type=float,
default=2.5e-4,
- help='the learning rate of the optimizer'
+ help="the learning rate of the optimizer"
)
parser.add_argument(
- '--seed', type=int, default=1, help='seed of the experiment'
+ "--seed", type=int, default=1, help="seed of the experiment"
)
parser.add_argument(
- '--total-timesteps',
+ "--total-timesteps",
type=int,
default=10000000,
- help='total timesteps of the experiments'
+ help="total timesteps of the experiments"
)
parser.add_argument(
- '--torch-deterministic',
+ "--torch-deterministic",
type=lambda x: bool(strtobool(x)),
default=True,
- nargs='?',
+ nargs="?",
const=True,
- help='if toggled, `torch.backends.cudnn.deterministic=False`'
+ help="if toggled, `torch.backends.cudnn.deterministic=False`"
)
parser.add_argument(
- '--cuda',
+ "--cuda",
type=lambda x: bool(strtobool(x)),
default=True,
- nargs='?',
+ nargs="?",
const=True,
- help='if toggled, cuda will be enabled by default'
+ help="if toggled, cuda will be enabled by default"
)
parser.add_argument(
- '--track',
+ "--track",
type=lambda x: bool(strtobool(x)),
default=False,
- nargs='?',
+ nargs="?",
const=True,
- help='if toggled, this experiment will be tracked with Weights and Biases'
+ help="if toggled, this experiment will be tracked with Weights and Biases"
)
parser.add_argument(
- '--wandb-project-name',
+ "--wandb-project-name",
type=str,
default="cleanRL",
help="the wandb's project name"
)
parser.add_argument(
- '--wandb-entity',
+ "--wandb-entity",
type=str,
default=None,
help="the entity (team) of wandb's project"
)
- parser.add_argument(
- '--capture-video',
- type=lambda x: bool(strtobool(x)),
- default=False,
- nargs='?',
- const=True,
- help='whether to capture videos of the agent performances '
- '(check out `videos` folder)'
- )

# Algorithm specific arguments
parser.add_argument(
- '--num-envs',
+ "--num-envs",
type=int,
- default=32,
- help='the number of parallel game environments'
+ default=8,
+ help="the number of parallel game environments"
)
parser.add_argument(
- '--num-steps',
+ "--num-steps",
type=int,
default=128,
- help='the number of steps to run in each environment per policy rollout'
+ help="the number of steps to run in each environment per policy rollout"
)
parser.add_argument(
- '--anneal-lr',
+ "--anneal-lr",
type=lambda x: bool(strtobool(x)),
default=True,
- nargs='?',
+ nargs="?",
const=True,
- help="Toggle learning rate annealing for policy and value networks"
+ help="toggle learning rate annealing for policy and value networks"
)
parser.add_argument(
- '--gae',
+ "--gae",
type=lambda x: bool(strtobool(x)),
default=True,
- nargs='?',
+ nargs="?",
const=True,
- help='Use GAE for advantage computation'
+ help="use GAE for advantage computation"
)
parser.add_argument(
- '--gamma', type=float, default=0.99, help='the discount factor gamma'
+ "--gamma", type=float, default=0.99, help="the discount factor gamma"
)
parser.add_argument(
- '--gae-lambda',
+ "--gae-lambda",
type=float,
default=0.95,
- help='the lambda for the general advantage estimation'
+ help="the lambda for the general advantage estimation"
)
parser.add_argument(
- '--num-minibatches',
+ "--num-minibatches",
type=int,
default=4,
- help='the number of mini-batches'
+ help="the number of mini-batches"
)
parser.add_argument(
- '--update-epochs',
+ "--update-epochs",
type=int,
default=4,
help="the K epochs to update the policy"
)
parser.add_argument(
- '--norm-adv',
+ "--norm-adv",
type=lambda x: bool(strtobool(x)),
default=True,
- nargs='?',
+ nargs="?",
const=True,
- help="Toggle advantages normalization"
+ help="toggle advantages normalization"
)
parser.add_argument(
- '--clip-coef',
+ "--clip-coef",
type=float,
default=0.1,
help="the surrogate clipping coefficient"
)
parser.add_argument(
- '--clip-vloss',
+ "--clip-vloss",
type=lambda x: bool(strtobool(x)),
default=True,
- nargs='?',
+ nargs="?",
const=True,
- help='Toggle whether or not to use a clipped loss '
- 'for the value function, as per the paper.'
+ help="toggle whether or not to use a clipped loss "
+ "for the value function, as per the paper."
)
parser.add_argument(
- '--ent-coef', type=float, default=0.01, help="coefficient of the entropy"
+ "--ent-coef", type=float, default=0.01, help="coefficient of the entropy"
)
parser.add_argument(
- '--vf-coef',
+ "--vf-coef",
type=float,
default=0.5,
help="coefficient of the value function"
)
parser.add_argument(
- '--max-grad-norm',
+ "--max-grad-norm",
type=float,
default=0.5,
- help='the maximum norm for the gradient clipping'
+ help="the maximum norm for the gradient clipping"
)
parser.add_argument(
- '--target-kl',
+ "--target-kl",
type=float,
default=None,
- help='the target KL divergence threshold'
+ help="the target KL divergence threshold"
)
args = parser.parse_args()
args.batch_size = int(args.num_envs * args.num_steps)
@@ -220,25 +212,37 @@ def __init__(self, env, deque_size=100):
self.num_envs = getattr(env, "num_envs", 1)
self.episode_returns = None
self.episode_lengths = None
- self.is_vector_env = True
+ # get if the env has lives
+ self.has_lives = False
+ env.reset()
+ info = env.step(np.zeros(self.num_envs, dtype=int))[-1]
+ if info["lives"].sum() > 0:
+ self.has_lives = True
+ print("env has lives")

def reset(self, **kwargs):
observations = super(RecordEpisodeStatistics, self).reset(**kwargs)
self.episode_returns = np.zeros(self.num_envs, dtype=np.float32)
self.episode_lengths = np.zeros(self.num_envs, dtype=np.int32)
+ self.lives = np.zeros(self.num_envs, dtype=np.int32)
self.returned_episode_returns = np.zeros(self.num_envs, dtype=np.float32)
self.returned_episode_lengths = np.zeros(self.num_envs, dtype=np.int32)
return observations

def step(self, action):
observations, rewards, dones, infos = super(RecordEpisodeStatistics,
self).step(action)
- self.episode_returns += rewards
+ self.episode_returns += infos["reward"]
self.episode_lengths += 1
self.returned_episode_returns[:] = self.episode_returns
self.returned_episode_lengths[:] = self.episode_lengths
- self.episode_returns *= (1 - dones)
- self.episode_lengths *= (1 - dones)
+ all_lives_exhausted = infos["lives"] == 0
+ if self.has_lives:
+ self.episode_returns *= 1 - all_lives_exhausted
+ self.episode_lengths *= 1 - all_lives_exhausted
+ else:
+ self.episode_returns *= 1 - dones
+ self.episode_lengths *= 1 - dones
infos["r"] = self.returned_episode_returns
infos["l"] = self.returned_episode_lengths
return (
Expand Down Expand Up @@ -270,7 +274,9 @@ def __init__(self, envs):
layer_init(nn.Linear(64 * 7 * 7, 512)),
nn.ReLU(),
)
self.actor = layer_init(nn.Linear(512, envs.action_space.n), std=0.01)
self.actor = layer_init(
nn.Linear(512, envs.single_action_space.n), std=0.01
)
self.critic = layer_init(nn.Linear(512, 1), std=1)

def get_value(self, x):
@@ -318,9 +324,16 @@ def get_action_and_value(self, x, action=None):
)

# env setup
- envs = envpool.make(args.gym_id, env_type="gym", num_envs=args.num_envs)
+ envs = envpool.make(
+ args.gym_id,
+ env_type="gym",
+ num_envs=args.num_envs,
+ episodic_life=True,
+ reward_clip=True,
+ )
envs.num_envs = args.num_envs
- # envs.is_vector_env = True
+ envs.single_action_space = envs.action_space
+ envs.single_observation_space = envs.observation_space
envs = RecordEpisodeStatistics(envs)
assert isinstance(
envs.action_space, gym.spaces.Discrete
@@ -331,15 +344,16 @@ def get_action_and_value(self, x, action=None):

# ALGO Logic: Storage setup
obs = torch.zeros(
- (args.num_steps, args.num_envs) + envs.observation_space.shape
+ (args.num_steps, args.num_envs) + envs.single_observation_space.shape
).to(device)
actions = torch.zeros(
- (args.num_steps, args.num_envs) + envs.action_space.shape
+ (args.num_steps, args.num_envs) + envs.single_action_space.shape
).to(device)
logprobs = torch.zeros((args.num_steps, args.num_envs)).to(device)
rewards = torch.zeros((args.num_steps, args.num_envs)).to(device)
dones = torch.zeros((args.num_steps, args.num_envs)).to(device)
values = torch.zeros((args.num_steps, args.num_envs)).to(device)
+ avg_returns = deque(maxlen=20)

# TRY NOT TO MODIFY: start the game
global_step = 0
@@ -375,8 +389,12 @@ def get_action_and_value(self, x, action=None):
).to(device)

for idx, d in enumerate(done):
- if d:
+ if d and info["lives"][idx] == 0:
print(f"global_step={global_step}, episodic_return={info['r'][idx]}")
+ avg_returns.append(info["r"][idx])
+ writer.add_scalar(
+ "charts/avg_episodic_return", np.average(avg_returns), global_step
+ )
writer.add_scalar(
"charts/episodic_return", info["r"][idx], global_step
)
@@ -415,14 +433,14 @@ def get_action_and_value(self, x, action=None):
advantages = returns - values

# flatten the batch
- b_obs = obs.reshape((-1,) + envs.observation_space.shape)
+ b_obs = obs.reshape((-1,) + envs.single_observation_space.shape)
b_logprobs = logprobs.reshape(-1)
- b_actions = actions.reshape((-1,) + envs.action_space.shape)
+ b_actions = actions.reshape((-1,) + envs.single_action_space.shape)
b_advantages = advantages.reshape(-1)
b_returns = returns.reshape(-1)
b_values = values.reshape(-1)

- # Optimizaing the policy and value network
+ # Optimizing the policy and value network
b_inds = np.arange(args.batch_size)
clipfracs = []
for _epoch in range(args.update_epochs):
@@ -440,7 +458,7 @@

with torch.no_grad():
# calculate approx_kl http://joschu.net/blog/kl-approx.html
- # old_approx_kl = (-logratio).mean()
+ old_approx_kl = (-logratio).mean()
approx_kl = ((ratio - 1) - logratio).mean()
clipfracs += [
((ratio - 1.0).abs() > args.clip_coef).float().mean().item()
@@ -499,6 +517,9 @@ def get_action_and_value(self, x, action=None):
writer.add_scalar("losses/value_loss", v_loss.item(), global_step)
writer.add_scalar("losses/policy_loss", pg_loss.item(), global_step)
writer.add_scalar("losses/entropy", entropy_loss.item(), global_step)
+ writer.add_scalar(
+ "losses/old_approx_kl", old_approx_kl.item(), global_step
+ )
writer.add_scalar("losses/approx_kl", approx_kl.item(), global_step)
writer.add_scalar("losses/clipfrac", np.mean(clipfracs), global_step)
writer.add_scalar("losses/explained_variance", explained_var, global_step)
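On the logging side, the patched rollout loop only reports a return once all lives are exhausted and additionally tracks a rolling average over the last 20 finished games. A minimal sketch of that step; the function wrapper and the writer argument are illustrative, while the info["r"] / info["lives"] fields and the chart names come from the diff above:

from collections import deque

import numpy as np

avg_returns = deque(maxlen=20)  # same window length as in the patch

def log_full_game_returns(done, info, global_step, writer):
    # With episodic_life=True, done also fires on every life loss, so require
    # info["lives"][idx] == 0 before treating the episode as a finished game.
    for idx, d in enumerate(done):
        if d and info["lives"][idx] == 0:
            episodic_return = info["r"][idx]
            avg_returns.append(episodic_return)
            writer.add_scalar("charts/episodic_return", episodic_return, global_step)
            writer.add_scalar(
                "charts/avg_episodic_return", np.average(avg_returns), global_step
            )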
