Match PPG implementation #186

Merged: 32 commits, May 28, 2022
Changes from 12 commits

Commits (32)
419041d  added nit changes from ppg code (dipamc, May 14, 2022)
2e1190b  change observation buffer to uint8 (dipamc, May 14, 2022)
86f5be7  sample full rollouts (dipamc, May 15, 2022)
beff293  minor device fix (dipamc, May 15, 2022)
4cb85d5  update optimizer settings (dipamc, May 16, 2022)
d6ee26b  add ppg documentation (May 18, 2022)
fea4531  update mkdocs (dipamc, May 18, 2022)
20f15da  update images to png for codespell errors (dipamc, May 18, 2022)
6c3cb05  trigger CI (vwxyzjn, May 18, 2022)
631ab96  Minor format change (vwxyzjn, May 18, 2022)
d961d0f  format by running `pre-commit` (vwxyzjn, May 18, 2022)
4cff11d  removes trailing space (vwxyzjn, May 18, 2022)
fb9c832  Add an extra note (vwxyzjn, May 19, 2022)
31bb5c4  argument names and documentation changes (dipamc, May 23, 2022)
ed66604  add capture video (dipamc, May 23, 2022)
1610191  add experiment report (dipamc, May 25, 2022)
51c6aac  Merge branch 'master' into ppg-dev (vwxyzjn, May 27, 2022)
a4342f8  Update documentation (vwxyzjn, May 27, 2022)
3d4711c  Quick css fix (vwxyzjn, May 27, 2022)
b780521  Update documentation (vwxyzjn, May 27, 2022)
9c4edf8  Fix documentation for PPO (vwxyzjn, May 27, 2022)
23cd48e  Add benchmark commands (vwxyzjn, May 27, 2022)
8e4f977  Add benchmark commands (vwxyzjn, May 27, 2022)
72e8cce  add metrics section (dipamc, May 27, 2022)
aa695c1  Add more docs (vwxyzjn, May 27, 2022)
0564584  Quick fix on ddpg docs (vwxyzjn, May 27, 2022)
a08039e  Add procgen test cases (vwxyzjn, May 27, 2022)
31a175c  Update CI (vwxyzjn, May 27, 2022)
f063a7b  test CI (vwxyzjn, May 27, 2022)
60df2c8  test ci (vwxyzjn, May 27, 2022)
e70c71a  Update tests (vwxyzjn, May 27, 2022)
6ebaaae  normalization axis documentation (dipamc, May 28, 2022)
4 changes: 3 additions & 1 deletion .github/workflows/pre-commit.yml
@@ -1,8 +1,10 @@
name: pre-commit

on:
  push:
    branches: [ master ]
  pull_request:
-   branches: [ '*' ]
+   branches: [ master ]
jobs:
  build:
    runs-on: ubuntu-latest
130 changes: 91 additions & 39 deletions cleanrl/ppg_procgen.py
@@ -56,8 +56,10 @@ def parse_args():
help="the lambda for the general advantage estimation")
parser.add_argument("--num-minibatches", type=int, default=8,
help="the number of mini-batches")
parser.add_argument("--norm-adv", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
parser.add_argument("--norm-adv", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
help="Toggles advantages normalization")
parser.add_argument("--norm-adv-ppg", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
help="Full batch advantage normalization as used in PPG code")
parser.add_argument("--clip-coef", type=float, default=0.2,
help="the surrogate clipping coefficient")
parser.add_argument("--clip-vloss", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
@@ -82,32 +84,51 @@ def parse_args():
help="E_aux:the K epochs to update the policy")
parser.add_argument("--beta-clone", type=float, default=1.0,
help="the behavior cloning coefficient")
parser.add_argument("--n-aux-minibatch", type=int, default=16 * 32,
parser.add_argument("--aux-num-rollouts", type=int, default=4,
help="the number of mini batch in the auxiliary phase")
parser.add_argument("--n-aux-grad-accum", type=int, default=1,
help="the number of gradient accumulation in mini batch")
args = parser.parse_args()
args.batch_size = int(args.num_envs * args.num_steps)
args.minibatch_size = int(args.batch_size // args.num_minibatches)
args.aux_batch_size = int(args.batch_size * args.n_iteration)
args.aux_minibatch_size = int(args.aux_batch_size // (args.n_aux_minibatch * args.n_aux_grad_accum))
args.aux_batch_rollouts = int(args.num_envs * args.n_iteration)
assert args.v_value == 1, "Multiple value epoch (v_value != 1) is not supported yet"
# fmt: on
return args


def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
torch.nn.init.orthogonal_(layer.weight, std)
torch.nn.init.constant_(layer.bias, bias_const)
def layer_init_normed(layer, norm_dim, scale=1.0):
with torch.no_grad():
layer.weight.data *= scale / layer.weight.norm(dim=norm_dim, p=2, keepdim=True)
layer.bias *= 0
return layer


def flatten01(arr):
return arr.reshape((-1, *arr.shape[2:]))


def unflatten01(arr, targetshape):
return arr.reshape((*targetshape, *arr.shape[1:]))


def flatten_unflatten_test():
a = torch.rand(400, 30, 100, 100, 5)
b = flatten01(a)
c = unflatten01(b, a.shape[:2])
assert torch.equal(a, c)


# taken from https://github.com/AIcrowd/neurips2020-procgen-starter-kit/blob/142d09586d2272a17f44481a115c4bd817cf6a94/models/impala_cnn_torch.py
class ResidualBlock(nn.Module):
def __init__(self, channels):
def __init__(self, channels, scale):
super().__init__()
self.conv0 = nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=3, padding=1)
self.conv1 = nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=3, padding=1)
# scale = (1/3**0.5 * 1/2**0.5)**0.5 # For default IMPALA CNN this is the final scale value in the PPG code
scale = np.sqrt(scale)
conv0 = nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=3, padding=1)
self.conv0 = layer_init_normed(conv0, norm_dim=(1, 2, 3), scale=scale)
conv1 = nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=3, padding=1)
self.conv1 = layer_init_normed(conv1, norm_dim=(1, 2, 3), scale=scale)

def forward(self, x):
inputs = x
@@ -119,13 +140,16 @@ def forward(self, x):


class ConvSequence(nn.Module):
def __init__(self, input_shape, out_channels):
def __init__(self, input_shape, out_channels, scale):
super().__init__()
self._input_shape = input_shape
self._out_channels = out_channels
self.conv = nn.Conv2d(in_channels=self._input_shape[0], out_channels=self._out_channels, kernel_size=3, padding=1)
self.res_block0 = ResidualBlock(self._out_channels)
self.res_block1 = ResidualBlock(self._out_channels)
conv = nn.Conv2d(in_channels=self._input_shape[0], out_channels=self._out_channels, kernel_size=3, padding=1)
self.conv = layer_init_normed(conv, norm_dim=(1, 2, 3), scale=1.0)
nblocks = 2 # Set to the number of residual blocks
scale = scale / np.sqrt(nblocks)
self.res_block0 = ResidualBlock(self._out_channels, scale=scale)
self.res_block1 = ResidualBlock(self._out_channels, scale=scale)

def forward(self, x):
x = self.conv(x)
@@ -146,20 +170,25 @@ def __init__(self, envs):
h, w, c = envs.single_observation_space.shape
shape = (c, h, w)
conv_seqs = []
for out_channels in [16, 32, 32]:
conv_seq = ConvSequence(shape, out_channels)
chans = [16, 32, 32]
scale = 1 / np.sqrt(len(chans))  # Not fully sure about the logic behind this, but it's used in the PPG code
for out_channels in chans:
conv_seq = ConvSequence(shape, out_channels, scale=scale)
shape = conv_seq.get_output_shape()
conv_seqs.append(conv_seq)

encodertop = nn.Linear(in_features=shape[0] * shape[1] * shape[2], out_features=256)
encodertop = layer_init_normed(encodertop, norm_dim=1, scale=1.4)
conv_seqs += [
nn.Flatten(),
nn.ReLU(),
nn.Linear(in_features=shape[0] * shape[1] * shape[2], out_features=256),
encodertop,
nn.ReLU(),
]
self.network = nn.Sequential(*conv_seqs)
self.actor = layer_init(nn.Linear(256, envs.single_action_space.n), std=0.01)
self.critic = layer_init(nn.Linear(256, 1), std=1)
self.aux_critic = layer_init(nn.Linear(256, 1), std=1)
self.actor = layer_init_normed(nn.Linear(256, envs.single_action_space.n), norm_dim=1, scale=0.1)
self.critic = layer_init_normed(nn.Linear(256, 1), norm_dim=1, scale=0.1)
self.aux_critic = layer_init_normed(nn.Linear(256, 1), norm_dim=1, scale=0.1)

def get_action_and_value(self, x, action=None):
hidden = self.network(x.permute((0, 3, 1, 2)) / 255.0) # "bhwc" -> "bchw"
@@ -202,6 +231,8 @@ def get_pi(self, x):
"|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
)

flatten_unflatten_test() # Try not to mess with the flatten unflatten logic

# TRY NOT TO MODIFY: seeding
random.seed(args.seed)
np.random.seed(args.seed)
@@ -222,7 +253,7 @@ def get_pi(self, x):
assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported"

agent = Agent(envs).to(device)
optimizer = optim.Adam(agent.parameters(), lr=args.learning_rate, eps=1e-5)
optimizer = optim.Adam(agent.parameters(), lr=args.learning_rate, eps=1e-8)

# ALGO Logic: Storage setup
obs = torch.zeros((args.num_steps, args.num_envs) + envs.single_observation_space.shape).to(device)
@@ -231,8 +262,10 @@ def get_pi(self, x):
rewards = torch.zeros((args.num_steps, args.num_envs)).to(device)
dones = torch.zeros((args.num_steps, args.num_envs)).to(device)
values = torch.zeros((args.num_steps, args.num_envs)).to(device)
aux_obs = torch.zeros((args.num_steps * args.num_envs * args.n_iteration,) + envs.single_observation_space.shape)
aux_returns = torch.zeros((args.num_steps * args.num_envs * args.n_iteration,))
aux_obs = torch.zeros(
(args.num_steps, args.aux_batch_rollouts) + envs.single_observation_space.shape, dtype=torch.uint8
) # Saves a lot of system RAM
aux_returns = torch.zeros((args.num_steps, args.aux_batch_rollouts))

# TRY NOT TO MODIFY: start the game
global_step = 0
@@ -312,6 +345,10 @@ def get_pi(self, x):
b_returns = returns.reshape(-1)
b_values = values.reshape(-1)

# PPG code does full batch advantage normalization
if args.norm_adv_ppg:
b_advantages = (b_advantages - b_advantages.mean()) / (b_advantages.std() + 1e-8)

# Optimizing the policy and value network
b_inds = np.arange(args.batch_size)
clipfracs = []
@@ -383,45 +420,60 @@ def get_pi(self, x):
print("SPS:", int(global_step / (time.time() - start_time)))
writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)

# PPG Storage:
storage_slice = slice(args.num_steps * args.num_envs * (update - 1), args.num_steps * args.num_envs * update)
aux_obs[storage_slice] = b_obs.cpu().clone()
aux_returns[storage_slice] = b_returns.cpu().clone()
# PPG Storage - Rollouts are saved without flattening for sampling full rollouts later:
storage_slice = slice(args.num_envs * (update - 1), args.num_envs * update)
aux_obs[:, storage_slice] = obs.cpu().clone().to(torch.uint8)
aux_returns[:, storage_slice] = returns.cpu().clone()

# AUXILIARY PHASE
old_agent = Agent(envs).to(device)
old_agent.load_state_dict(agent.state_dict())
aux_inds = np.arange(
args.aux_batch_size,
)
aux_inds = np.arange(args.aux_batch_rollouts)

# Build the old policy on the aux buffer before distilling to the network
aux_pi = torch.zeros((args.num_steps, args.aux_batch_rollouts, envs.single_action_space.n))
for i, start in enumerate(range(0, args.aux_batch_rollouts, args.aux_num_rollouts)):
end = start + args.aux_num_rollouts
aux_minibatch_ind = aux_inds[start:end]
m_aux_obs = aux_obs[:, aux_minibatch_ind].to(torch.float32).to(device)
m_obs_shape = m_aux_obs.shape
m_aux_obs = flatten01(m_aux_obs)
with torch.no_grad():
pi_logits = agent.get_pi(m_aux_obs).logits.cpu().clone()
aux_pi[:, aux_minibatch_ind] = unflatten01(pi_logits, m_obs_shape[:2])
del m_aux_obs

for auxiliary_update in range(1, args.e_auxiliary + 1):
print(f"aux epoch {auxiliary_update}")
np.random.shuffle(aux_inds)
for i, start in enumerate(range(0, args.aux_batch_size, args.aux_minibatch_size)):
end = start + args.aux_minibatch_size
for i, start in enumerate(range(0, args.aux_batch_rollouts, args.aux_num_rollouts)):
end = start + args.aux_num_rollouts
aux_minibatch_ind = aux_inds[start:end]
try:
m_aux_obs = aux_obs[aux_minibatch_ind].to(device)
m_aux_returns = aux_returns[aux_minibatch_ind].to(device)
m_aux_obs = aux_obs[:, aux_minibatch_ind].to(device)
m_obs_shape = m_aux_obs.shape
m_aux_obs = flatten01(m_aux_obs) # Sample full rollouts for PPG instead of random indexes
m_aux_returns = aux_returns[:, aux_minibatch_ind].to(torch.float32).to(device)
m_aux_returns = flatten01(m_aux_returns)

new_pi, new_values, new_aux_values = agent.get_pi_value_and_aux_value(m_aux_obs)

new_values = new_values.view(-1)
new_aux_values = new_aux_values.view(-1)
with torch.no_grad():
old_pi = old_agent.get_pi(m_aux_obs)
old_pi_logits = flatten01(aux_pi[:, aux_minibatch_ind]).to(device)
old_pi = Categorical(logits=old_pi_logits)
kl_loss = td.kl_divergence(old_pi, new_pi).mean()

real_value_loss = 0.5 * ((new_values - m_aux_returns) ** 2).mean()
aux_value_loss = 0.5 * ((new_aux_values - m_aux_returns) ** 2).mean()
joint_loss = aux_value_loss + args.beta_clone * kl_loss

optimizer.zero_grad()
loss = (joint_loss + real_value_loss) / args.n_aux_grad_accum
loss.backward()

if (i + 1) % args.n_aux_grad_accum == 0:
nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm)
optimizer.step()
optimizer.zero_grad() # This cannot be outside, else gradients won't accumulate

except RuntimeError:
raise Exception(
"if running out of CUDA memory, try a higher --n-aux-grad-accum, which trades more time for less gpu memory"
3 changes: 3 additions & 0 deletions docs/rl-algorithms.md
@@ -30,3 +30,6 @@ Below are the implemented algorithms and their brief descriptions.
- [x] Twin Delayed Deep Deterministic Policy Gradient (TD3)
* [td3_continuous_action.py](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/td3_continuous_action.py)
* For continuous action space.
- [x] Phasic Policy Gradient (PPG)
* [ppg_procgen.py](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ppg_procgen.py)
* PPG implementation for Procgen
1 change: 1 addition & 0 deletions docs/rl-algorithms/overview.md
@@ -14,4 +14,5 @@
| ✅ [Soft Actor-Critic (SAC)](https://arxiv.org/pdf/1812.05905.pdf) | :material-github: [`sac_continuous_action.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/sac_continuous_action.py) |
| ✅ [Deep Deterministic Policy Gradient (DDPG)](https://arxiv.org/pdf/1509.02971.pdf) | :material-github: [`ddpg_continuous_action.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ddpg_continuous_action.py) |
| ✅ [Twin Delayed Deep Deterministic Policy Gradient (TD3)](https://arxiv.org/pdf/1802.09477.pdf) | :material-github: [`td3_continuous_action.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/td3_continuous_action.py) |
| ✅ [Phasic Policy Gradient (PPG)](https://arxiv.org/abs/2009.04416) | :material-github: [`ppg_procgen.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ppg_procgen.py) |

93 changes: 93 additions & 0 deletions docs/rl-algorithms/ppg.md
@@ -0,0 +1,93 @@
# Phasic Policy Gradient (PPG)

## Overview

PPG is a deep reinforcement learning algorithm that separates policy and value function training by introducing an auxiliary phase. Training proceeds by running PPO during the policy phase while saving all the experience into a buffer. During the auxiliary phase, that buffer is used to train the value function and to distill its features into the policy network, with a behavior-cloning KL term keeping the policy close to its pre-auxiliary-phase behavior. This makes the algorithm considerably slower than PPO, but improves sample efficiency on the Procgen benchmark.
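
For orientation, here is a minimal sketch of that phase structure. The function names are placeholders standing in for the real logic in `ppg_procgen.py`, not its actual API:

```python
# Minimal sketch of the PPG training loop described above (placeholder functions).
def collect_rollout():
    return {"obs": [], "returns": []}   # one on-policy rollout from the vectorized env

def ppo_update(rollout):
    pass                                # policy phase: a standard PPO update

def aux_update(buffer):
    pass                                # auxiliary phase: value training + distillation
                                        # into the policy network, with a KL penalty

N_PI = 32            # policy-phase iterations per auxiliary phase (N_pi in the paper)
buffer = []
for phase in range(4):
    for _ in range(N_PI):               # policy phase
        rollout = collect_rollout()
        ppo_update(rollout)
        buffer.append(rollout)          # experience is kept for the auxiliary phase
    aux_update(buffer)                  # auxiliary phase runs on the whole buffer
    buffer.clear()
```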

Original paper:

* [Phasic Policy Gradient](https://arxiv.org/abs/2009.04416)

Reference resources:

* [Code for the paper "Phasic Policy Gradient"](https://github.com/openai/phasic-policy-gradient) - by original authors from OpenAI

The original code has multiple code-level details that are not mentioned in the paper. We found these details to be important for reproducing the results claimed by the paper.

## Implemented Variants


| Variants Implemented | Description |
| ----------- | ----------- |
| :material-github: [`ppg_procgen.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ppg_procgen.py), :material-file-document: [docs](/rl-algorithms/ppg/#ppg_procgenpy) | For the Procgen benchmark (64x64 RGB image observations, discrete actions). |

Below are our single-file implementations of PPG:

## `ppg_procgen.py`

`ppg_procgen.py` works with the Procgen benchmark, which uses 64x64 RGB image observations and discrete actions.
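
For reference, the expected spaces can be inspected directly. The snippet below is a sketch assuming the `procgen` package's `ProcgenEnv` constructor (installed with `poetry install -E procgen`); treat the exact constructor arguments as an illustration rather than the script's exact setup:

```python
# Sketch: inspect the observation/action spaces that ppg_procgen.py builds on.
from procgen import ProcgenEnv

envs = ProcgenEnv(num_envs=4, env_name="bigfish", num_levels=0,
                  start_level=0, distribution_mode="easy")
print(envs.observation_space)  # a Dict space whose "rgb" entry is roughly Box(0, 255, (64, 64, 3), uint8)
print(envs.action_space)       # Discrete(15) -- Procgen games share one 15-action space
```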

### Usage

```bash
poetry install -E procgen
python cleanrl/ppg_procgen.py --help
python cleanrl/ppg_procgen.py --env-id "bigfish"
```

### Implementation details

`ppg_procgen.py` includes the following four code-level implementation details that differ from PPO:

1. Full rollout sampling during auxiliary phase - (:material-github: [phasic_policy_gradient/ppg.py#L173](https://github.com/openai/phasic-policy-gradient/blob/c789b00be58aa704f7223b6fc8cd28a5aaa2e101/phasic_policy_gradient/ppg.py#L173)) - Instead of randomly sampling observations over the entire auxiliary buffer, PPG samples full rollouts from the buffer (sets of 256 steps). This full rollout sampling is only done during the auxiliary phase. Note that the rollouts still start at random points, because PPO truncates the rollouts per environment. This change gives a decent performance boost.

1. Batch level advantage normalization - PPG normalizes the advantages over the full batch before the PPO updates, instead of normalizing them per minibatch. (:material-github: [phasic_policy_gradient/ppo.py#L70](https://github.com/openai/phasic-policy-gradient/blob/c789b00be58aa704f7223b6fc8cd28a5aaa2e101/phasic_policy_gradient/ppo.py#L70))

1. Normalized network initialization - (:material-github: [phasic_policy_gradient/impala_cnn.py#L64](https://github.com/openai/phasic-policy-gradient/blob/c789b00be58aa704f7223b6fc8cd28a5aaa2e101/phasic_policy_gradient/impala_cnn.py#L64)) - PPG uses normalized initialization for all layers, with different scales.
    * The original PPO used orthogonal initialization of only the policy head and value head, with scales of 0.01 and 1.0 respectively.
    * For PPG:
        * All weights are initialized with the default PyTorch initialization (Kaiming uniform).
        * Each layer's weights are then divided by the L2 norm of the weights taken along every dimension except the output dimension (`norm_dim=(1, 2, 3)` for convolutional layers, `norm_dim=1` for linear layers), and multiplied by a scale factor.
        * Scale factors for the different layers (a short derivation of the convolutional value is sketched after this list):
            * Value head, policy head, auxiliary value head - 0.1
            * Fully connected layer after the last convolutional layer - 1.4
            * Convolutional layers - approximately 0.638
1. The Adam optimizer's epsilon parameter - (:material-github: [phasic_policy_gradient/ppg.py#L239](https://github.com/openai/phasic-policy-gradient/blob/c789b00be58aa704f7223b6fc8cd28a5aaa2e101/phasic_policy_gradient/ppg.py#L239)) - Set to the PyTorch default of 1e-8 instead of the 1e-5 used in PPO.
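
The convolutional scale of approximately 0.638 quoted above is not an independent hyperparameter; it falls out of how the scales are composed in `ppg_procgen.py` (see `Agent`, `ConvSequence`, and `ResidualBlock` in the diff above). A small sketch of that arithmetic:

```python
import numpy as np

# Scale composition as wired up in ppg_procgen.py:
agent_scale = 1 / np.sqrt(3)                # Agent: chans = [16, 32, 32] -> 3 ConvSequences
sequence_scale = agent_scale / np.sqrt(2)   # ConvSequence: divided by sqrt(2 residual blocks)
block_scale = np.sqrt(sequence_scale)       # ResidualBlock: scale = np.sqrt(scale)
print(block_scale)                          # ~0.6389, i.e. the "approximately 0.638" above
```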


### Experiment results

Below are the average episodic returns for `ppg_procgen.py`, compared with `ppo_procgen.py` over 25M timesteps.

| Environment | `ppg_procgen.py` | `ppo_procgen.py` |
| ----------- | ----------- | ----------- |
| Bigfish (easy) | 27.670 ± 9.523 | 21.605 ± 7.996 |
| Starpilot (easy) | 39.086 ± 11.042 | 34.025 ± 12.535 |

Learning curves:

<div class="grid-container">

<img src="../ppg/bigfish-easy-ppg-ppo.png">

<img src="../ppg/starpilot-easy-ppg-ppo.png">

</div>

Tracked experiments and game play videos:

To be added

### Extra notes

- All the default hyperparameters from the original PPG implementation are used, except the number of environments, which is set to 64.
- The original PPG paper does not report results on the easy set of environments, so further hyperparameter tuning may give better results there.
- Skipping every alternate auxiliary phase gives similar performance on the easy environments while saving compute.
- The normalized network initialization scheme seems to matter a lot, but layer norm with orthogonal initialization also works.
- Mixed precision works well for saving compute in the auxiliary phase, but using it in the policy phase makes training unstable (a minimal sketch follows this list).
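
To illustrate the mixed-precision note above, here is a minimal, self-contained sketch of wrapping an auxiliary-style value-regression step in `torch.cuda.amp`. The toy linear model and tensors are stand-ins, not the actual `ppg_procgen.py` auxiliary update, and a CUDA device is assumed:

```python
import torch
import torch.nn as nn

device = "cuda"                                   # GradScaler is CUDA-only
value_head = nn.Linear(256, 1).to(device)         # stand-in for the auxiliary value head
optimizer = torch.optim.Adam(value_head.parameters(), lr=5e-4, eps=1e-8)
scaler = torch.cuda.amp.GradScaler()

features = torch.randn(64, 256, device=device)    # stand-in for encoder features
returns = torch.randn(64, 1, device=device)       # stand-in for stored returns

optimizer.zero_grad()
with torch.cuda.amp.autocast():                   # forward pass + loss in mixed precision
    aux_value_loss = 0.5 * ((value_head(features) - returns) ** 2).mean()
scaler.scale(aux_value_loss).backward()           # scale the loss to avoid fp16 underflow
scaler.step(optimizer)                            # unscale gradients, then optimizer step
scaler.update()
```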


### Differences from the original PPG code

- The original PPG code supports LSTM whereas the CleanRL code does not.
- The original PPG code uses separate optimizers for the policy and auxiliary phases, but we do not implement this, as we found it does not make much difference.
Binary file added docs/rl-algorithms/ppg/bigfish-easy-ppg-ppo.png
Binary file added docs/rl-algorithms/ppg/starpilot-easy-ppg-ppo.png
1 change: 1 addition & 0 deletions mkdocs.yml
@@ -84,6 +84,7 @@ nav:
- rl-algorithms/ddpg.md
- rl-algorithms/sac.md
- rl-algorithms/td3.md
- rl-algorithms/ppg.md
- Open RL Benchmark: open-rl-benchmark.md
- Advanced:
- advanced/resume-training.md