diff --git a/experiments/issue51_new/stable_baselinesDDPG.py b/experiments/issue51_new/stable_baselinesDDPG.py index 9c8b47ef..5713cfd2 100644 --- a/experiments/issue51_new/stable_baselinesDDPG.py +++ b/experiments/issue51_new/stable_baselinesDDPG.py @@ -2,10 +2,6 @@ from os import makedirs from typing import List -import torch as th -import torch.nn as nn -from stable_baselines3.common.torch_layers import BaseFeaturesExtractor - import gym import numpy as np from stable_baselines3 import DDPG @@ -114,47 +110,13 @@ def _on_step(self) -> bool: n_actions = env.action_space.shape[-1] action_noise = NormalActionNoise(mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions)) - -class CustomMPL(BaseFeaturesExtractor): - - def __init__(self, observation_space: gym.spaces.Box, features_dim: int = 256): - super(CustomMPL, self).__init__(observation_space, features_dim) - # We assume CxHxW images (channels first) - # Re-ordering will be done by pre-preprocessing or wrapper - n_input_channels = observation_space.shape[0] - self.cnn = nn.Sequential( - nn.Linear(n_input_channels, 32), - nn.ReLU(), - nn.Linear(32, 64), - nn.ReLU(), - ) - - # Compute shape by doing one forward pass - with th.no_grad(): - n_flatten = self.cnn( - th.as_tensor(observation_space.sample()[None]).float() - ).shape[1] - - self.linear = nn.Sequential(nn.Linear(n_flatten, features_dim), nn.ReLU()) - - def forward(self, observations: th.Tensor) -> th.Tensor: - return self.linear(self.cnn(observations)) - - -policy_kwargs = dict( - features_extractor_class=CustomMPL, - features_extractor_kwargs=dict(features_dim=128, net_arch=[32, 32]), -) - -# policy_kwargs = dict(net_arch=dict(pi=[5, 5], qf=[10, 10])) -# policy_kwargs = dict( activation_fn=th.nn.LeakyReLU, net_arch=[32, 32]) -model = DDPG('MlpPolicy', env, verbose=1, tensorboard_log=f'{timestamp}/', policy_kwargs=policy_kwargs) -checkpoint_on_event = CheckpointCallback(save_freq=10000, save_path=f'{timestamp}/checkpoints/') +model = DDPG('MlpPolicy', env, verbose=1, tensorboard_log=f'{timestamp}/') +checkpoint_on_event = CheckpointCallback(save_freq=100000, save_path=f'{timestamp}/checkpoints/') record_env = RecordEnvCallback() -plot_callback = EveryNTimesteps(n_steps=10000, callback=record_env) -model.learn(total_timesteps=50000, callback=[checkpoint_on_event, plot_callback]) +plot_callback = EveryNTimesteps(n_steps=50000, callback=record_env) +model.learn(total_timesteps=500000, callback=[checkpoint_on_event, plot_callback]) -model.save('ddpg_CC2') +model.save('ddpg_CC') del model # remove to demonstrate saving and loading