
Commit

#51 added DDPG example
Daniel Weber committed May 19, 2021
1 parent 84c8f28 commit 6f86e2a
Showing 1 changed file with 5 additions and 43 deletions.
48 changes: 5 additions & 43 deletions experiments/issue51_new/stable_baselinesDDPG.py
@@ -2,10 +2,6 @@
 from os import makedirs
 from typing import List
 
-import torch as th
-import torch.nn as nn
-from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
-
 import gym
 import numpy as np
 from stable_baselines3 import DDPG
@@ -114,47 +110,13 @@ def _on_step(self) -> bool:
 n_actions = env.action_space.shape[-1]
 action_noise = NormalActionNoise(mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions))
 
-
-class CustomMPL(BaseFeaturesExtractor):
-
-    def __init__(self, observation_space: gym.spaces.Box, features_dim: int = 256):
-        super(CustomMPL, self).__init__(observation_space, features_dim)
-        # We assume CxHxW images (channels first)
-        # Re-ordering will be done by pre-preprocessing or wrapper
-        n_input_channels = observation_space.shape[0]
-        self.cnn = nn.Sequential(
-            nn.Linear(n_input_channels, 32),
-            nn.ReLU(),
-            nn.Linear(32, 64),
-            nn.ReLU(),
-        )
-
-        # Compute shape by doing one forward pass
-        with th.no_grad():
-            n_flatten = self.cnn(
-                th.as_tensor(observation_space.sample()[None]).float()
-            ).shape[1]
-
-        self.linear = nn.Sequential(nn.Linear(n_flatten, features_dim), nn.ReLU())
-
-    def forward(self, observations: th.Tensor) -> th.Tensor:
-        return self.linear(self.cnn(observations))
-
-
-policy_kwargs = dict(
-    features_extractor_class=CustomMPL,
-    features_extractor_kwargs=dict(features_dim=128, net_arch=[32, 32]),
-)
-
-# policy_kwargs = dict(net_arch=dict(pi=[5, 5], qf=[10, 10]))
-# policy_kwargs = dict( activation_fn=th.nn.LeakyReLU, net_arch=[32, 32])
-model = DDPG('MlpPolicy', env, verbose=1, tensorboard_log=f'{timestamp}/', policy_kwargs=policy_kwargs)
-checkpoint_on_event = CheckpointCallback(save_freq=10000, save_path=f'{timestamp}/checkpoints/')
+model = DDPG('MlpPolicy', env, verbose=1, tensorboard_log=f'{timestamp}/')
+checkpoint_on_event = CheckpointCallback(save_freq=100000, save_path=f'{timestamp}/checkpoints/')
 record_env = RecordEnvCallback()
-plot_callback = EveryNTimesteps(n_steps=10000, callback=record_env)
-model.learn(total_timesteps=50000, callback=[checkpoint_on_event, plot_callback])
+plot_callback = EveryNTimesteps(n_steps=50000, callback=record_env)
+model.learn(total_timesteps=500000, callback=[checkpoint_on_event, plot_callback])
 
-model.save('ddpg_CC2')
+model.save('ddpg_CC')
 
 del model # remove to demonstrate saving and loading
 
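The script saves the trained agent as 'ddpg_CC' and then deletes the model object to demonstrate saving and loading. A minimal sketch of how the saved agent could be reloaded and evaluated with the standard stable-baselines3 API (assuming env is the same environment instance built earlier in the script; the rollout length is illustrative only, not part of this commit):

from stable_baselines3 import DDPG

# Reload the agent that was saved above as 'ddpg_CC'
model = DDPG.load('ddpg_CC')

# Roll out the policy deterministically (no exploration noise) for evaluation
obs = env.reset()
for _ in range(1000):
    action, _states = model.predict(obs, deterministic=True)
    obs, reward, done, info = env.step(action)
    if done:
        obs = env.reset()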
