From 407fa8bc5b3cd0c777a0363df7c7519a80197895 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Tue, 15 Jan 2019 14:27:13 +0100 Subject: [PATCH 1/6] Update tests + changelog --- docs/misc/changelog.rst | 3 ++- tests/test_custom_policy.py | 50 ++++++++++++++++++++++++------------- tests/test_lstm_policy.py | 4 +-- tests/test_vec_envs.py | 2 +- tests/test_vec_normalize.py | 3 ++- 5 files changed, 39 insertions(+), 23 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 2ac34da84e..52b45d08d2 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -18,7 +18,8 @@ Pre-Release 2.4.0a (WIP) - added optional parameter to action_probability for likelihood calculation of given action being taken. - added more flexible custom LSTM policies - added auto entropy coefficient optimization for SAC -- clip continuous actions at test time too for all algorithms (except SAC/DDPG where it is not needed) +- clip continuous actions at test time too for all algorithms (except SAC/DDPG where it is not needed +- added a mean to pass kwargs to policy when creating a model (+ save those kwargs) Release 2.3.0 (2018-12-05) diff --git a/tests/test_custom_policy.py b/tests/test_custom_policy.py index 28fd46a522..65c6b933d0 100644 --- a/tests/test_custom_policy.py +++ b/tests/test_custom_policy.py @@ -13,39 +13,53 @@ N_TRIALS = 100 + class CustomCommonPolicy(FeedForwardPolicy): def __init__(self, *args, **kwargs): + # Default value + if 'net_arch' not in kwargs: + kwargs['net_arch'] = [8, dict(vf=[8, 8], pi=[8, 8])] super(CustomCommonPolicy, self).__init__(*args, **kwargs, - net_arch=[8, dict(vf=[8, 8], pi=[8, 8])], - feature_extraction="mlp") + feature_extraction="mlp") + class CustomDQNPolicy(DQNPolicy): def __init__(self, *args, **kwargs): + # Default value + if 'layers' not in kwargs: + kwargs['layers'] = [8, 8] super(CustomDQNPolicy, self).__init__(*args, **kwargs, - layers=[8, 8], - feature_extraction="mlp") + feature_extraction="mlp") + class CustomDDPGPolicy(DDPGPolicy): def __init__(self, *args, **kwargs): + # Default value + if 'layers' not in kwargs: + kwargs['layers'] = [8, 8] super(CustomDDPGPolicy, self).__init__(*args, **kwargs, - layers=[8, 8], - feature_extraction="mlp") + feature_extraction="mlp") + class CustomSACPolicy(SACPolicy): def __init__(self, *args, **kwargs): + # Default value + if 'layers' not in kwargs: + kwargs['layers'] = [8, 8] super(CustomSACPolicy, self).__init__(*args, **kwargs, - layers=[8, 8], - feature_extraction="mlp") + feature_extraction="mlp") + +# MODEL_CLASS, POLICY_CLASS, POLICY_KWARGS MODEL_DICT = { - 'a2c': (A2C, CustomCommonPolicy, dict(act_fun=tf.nn.relu)), + 'a2c': (A2C, CustomCommonPolicy, dict(act_fun=tf.nn.relu, net_arch=[12, dict(vf=[16], pi=[8])])), 'acer': (ACER, CustomCommonPolicy, dict(act_fun=tf.nn.relu)), 'acktr': (ACKTR, CustomCommonPolicy, dict(act_fun=tf.nn.relu)), - 'dqn': (DQN, CustomDQNPolicy, dict()), - 'ddpg': (DDPG, CustomDDPGPolicy, dict()), - 'ppo1': (PPO1, CustomCommonPolicy, dict(act_fun=tf.nn.relu)), - 'ppo2': (PPO2, CustomCommonPolicy, dict(act_fun=tf.nn.relu)), - 'sac': (SAC, CustomSACPolicy, dict()), + 'dqn': (DQN, CustomDQNPolicy, dict(layers=[4, 4], dueling=False)), + 'ddpg': (DDPG, CustomDDPGPolicy, dict(layers=[16, 16], layer_norm=False)), + 'ppo1': (PPO1, CustomCommonPolicy, dict(act_fun=tf.nn.relu, net_arch=[8, 4])), + 'ppo2': (PPO2, CustomCommonPolicy, dict(act_fun=tf.nn.relu, net_arch=[4, 4])), + 'sac': (SAC, CustomSACPolicy, dict(layers=[16, 16])), 'trpo': (TRPO, CustomCommonPolicy, 
dict(act_fun=tf.nn.relu)), } @@ -54,7 +68,7 @@ def __init__(self, *args, **kwargs): def test_custom_policy(model_name): """ Test if the algorithm (with a custom policy) can be loaded and saved without any issues. - :param model_class: (BaseRLModel) A RL model + :param model_name: (str) A RL model """ try: @@ -78,7 +92,7 @@ def test_custom_policy(model_name): model.save("./test_model") del model, env # loading - model = model_class.load("./test_model", policy=policy) + _ = model_class.load("./test_model", policy=policy) finally: if os.path.exists("./test_model"): @@ -89,7 +103,7 @@ def test_custom_policy(model_name): def test_custom_policy_kwargs(model_name): """ Test if the algorithm (with a custom policy) can be loaded and saved without any issues. - :param model_class: (BaseRLModel) A RL model + :param model_name: (str) A RL model """ try: @@ -119,7 +133,7 @@ def test_custom_policy_kwargs(model_name): # Load wit different wrong policy_kwargs with pytest.raises(ValueError): - model = model_class.load("./test_model", policy=policy, env=env, policy_kwargs=dict(wrong="kwargs")) + _ = model_class.load("./test_model", policy=policy, env=env, policy_kwargs=dict(wrong="kwargs")) finally: if os.path.exists("./test_model"): diff --git a/tests/test_lstm_policy.py b/tests/test_lstm_policy.py index 0648474a14..79e0ae6875 100644 --- a/tests/test_lstm_policy.py +++ b/tests/test_lstm_policy.py @@ -9,7 +9,7 @@ class CustomLSTMPolicy1(LstmPolicy): def __init__(self, sess, ob_space, ac_space, n_env, n_steps, n_batch, n_lstm=128, reuse=False, **_kwargs): super().__init__(sess, ob_space, ac_space, n_env, n_steps, n_batch, n_lstm, reuse, net_arch=[8, 'lstm', 8], - layer_norm=False, feature_extraction="mlp", **_kwargs) + layer_norm=False, feature_extraction="mlp", **_kwargs) class CustomLSTMPolicy2(LstmPolicy): @@ -58,7 +58,7 @@ def test_lstm_policy(model_class, policy): model.save("./test_model") del model, env # loading - model = model_class.load("./test_model", policy=policy) + _ = model_class.load("./test_model", policy=policy) finally: if os.path.exists("./test_model"): diff --git a/tests/test_vec_envs.py b/tests/test_vec_envs.py index 1597c6f9f1..6d549de2d3 100644 --- a/tests/test_vec_envs.py +++ b/tests/test_vec_envs.py @@ -23,7 +23,7 @@ def reset(self): return self.state def step(self, action): - reward = self._get_reward(action) + reward = 1 self._choose_next_state() self.current_step += 1 done = self.current_step >= self.ep_length diff --git a/tests/test_vec_normalize.py b/tests/test_vec_normalize.py index 609a483d8f..6d39bbd47e 100644 --- a/tests/test_vec_normalize.py +++ b/tests/test_vec_normalize.py @@ -56,4 +56,5 @@ def test_mpi_moments(): test running mean std function """ subprocess.check_call(['mpirun', '--allow-run-as-root', '-np', '3', 'python', '-c', - 'from stable_baselines.common.mpi_moments import _helper_runningmeanstd; _helper_runningmeanstd()']) + 'from stable_baselines.common.mpi_moments ' + 'import _helper_runningmeanstd; _helper_runningmeanstd()']) From 602c5ca62e3afd9d3beb730e0461f81239d10801 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Tue, 15 Jan 2019 18:03:28 +0100 Subject: [PATCH 2/6] Add example + kwargs check + fix for TRPO --- docs/guide/custom_policy.rst | 28 ++++++++++++++++++++++++++- docs/misc/changelog.rst | 2 +- docs/modules/policies.rst | 6 +++--- stable_baselines/common/policies.py | 14 ++++++++++++++ stable_baselines/ddpg/policies.py | 2 ++ stable_baselines/deepq/policies.py | 3 +++ stable_baselines/sac/policies.py | 2 ++ 
stable_baselines/trpo_mpi/trpo_mpi.py | 4 ++-- tests/test_custom_policy.py | 4 ++++ 9 files changed, 58 insertions(+), 7 deletions(-) diff --git a/docs/guide/custom_policy.rst b/docs/guide/custom_policy.rst index 1131bb400c..fee6c5be69 100644 --- a/docs/guide/custom_policy.rst +++ b/docs/guide/custom_policy.rst @@ -5,7 +5,33 @@ Custom Policy Network Stable baselines provides default policy networks (see :ref:`Policies ` ) for images (CNNPolicies) and other type of input features (MlpPolicies). -However, you can also easily define a custom architecture for the policy (or value) network: + +One way of customising the policy network architecture is to pass argument using ``policy_kwargs`` parameter: + +.. code-block:: python + + import gym + import tensorflow as tf + + from stable_baselines import PPO2 + + # Custom MLP policy of two layers of size 32 each with tanh activation function + policy_kwargs = dict(act_fun=tf.nn.tanh, net_arch=[32, 32]) + # Create the agent + model = PPO2("MlpPolicy", "CartPole-v1", policy_kwargs=policy_kwargs, verbose=1) + # Retrieve the environment + env = model.get_env() + # Train the agent + model.learn(total_timesteps=100000) + # Save the agent + model.save("ppo2-cartpole") + + del model + # the policy_kwargs are automatically loaded + model = PPO2.load("ppo2-cartpole") + + +You can also easily define a custom architecture for the policy (or value) network: .. code-block:: python diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 52b45d08d2..f56edae903 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -18,7 +18,7 @@ Pre-Release 2.4.0a (WIP) - added optional parameter to action_probability for likelihood calculation of given action being taken. - added more flexible custom LSTM policies - added auto entropy coefficient optimization for SAC -- clip continuous actions at test time too for all algorithms (except SAC/DDPG where it is not needed +- clip continuous actions at test time too for all algorithms (except SAC/DDPG where it is not needed) - added a mean to pass kwargs to policy when creating a model (+ save those kwargs) diff --git a/docs/modules/policies.rst b/docs/modules/policies.rst index 448df7c187..871241e257 100644 --- a/docs/modules/policies.rst +++ b/docs/modules/policies.rst @@ -6,9 +6,9 @@ Policy Networks =============== Stable-baselines provides a set of default policies, that can be used with most action spaces. -To customize the default policies, you can specify the `policy_kwargs` parameter to the model class you use. -Those kwargs are then passed to the policy on instantiation. -If you need more control on the policy architecture, You can also create a custom policy (see :ref:`custom_policy`). +To customize the default policies, you can specify the ``policy_kwargs`` parameter to the model class you use. +Those kwargs are then passed to the policy on instantiation (see :ref:`custom_policy` for an example). +If you need more control on the policy architecture, you can also create a custom policy (see :ref:`custom_policy`). .. 
note:: diff --git a/stable_baselines/common/policies.py b/stable_baselines/common/policies.py index 5e86bc7acc..57f92c77a7 100644 --- a/stable_baselines/common/policies.py +++ b/stable_baselines/common/policies.py @@ -123,6 +123,16 @@ def __init__(self, sess, ob_space, ac_space, n_env, n_steps, n_batch, reuse=Fals self.ob_space = ob_space self.ac_space = ac_space + def _kwargs_check(self, feature_extraction, kwargs): + """ + Ensure that the user is not passing wrong keywords + when using policy_kwargs + :param feature_extraction: (str) + :param kwargs: (dict) + """ + if feature_extraction == 'mlp' and len(kwargs) > 0: + raise ValueError("Unknown keywords for policy: {}".format(kwargs)) + def step(self, obs, state=None, mask=None): """ Returns the policy for a single step @@ -255,6 +265,8 @@ def __init__(self, sess, ob_space, ac_space, n_env, n_steps, n_batch, n_lstm=256 super(LstmPolicy, self).__init__(sess, ob_space, ac_space, n_env, n_steps, n_batch, reuse, scale=(feature_extraction == "cnn")) + self._kwargs_check(feature_extraction, kwargs) + with tf.variable_scope("input", reuse=True): self.masks_ph = tf.placeholder(tf.float32, [n_batch], name="masks_ph") # mask (done t-1) # n_lstm * 2 dim because of the cell and hidden states of the LSTM @@ -394,6 +406,8 @@ def __init__(self, sess, ob_space, ac_space, n_env, n_steps, n_batch, reuse=Fals super(FeedForwardPolicy, self).__init__(sess, ob_space, ac_space, n_env, n_steps, n_batch, reuse=reuse, scale=(feature_extraction == "cnn")) + self._kwargs_check(feature_extraction, kwargs) + if layers is not None: warnings.warn("Usage of the `layers` parameter is deprecated! Use net_arch instead " "(it has a different semantics though).", DeprecationWarning) diff --git a/stable_baselines/ddpg/policies.py b/stable_baselines/ddpg/policies.py index 9483996016..062ab01f34 100644 --- a/stable_baselines/ddpg/policies.py +++ b/stable_baselines/ddpg/policies.py @@ -107,6 +107,8 @@ def __init__(self, sess, ob_space, ac_space, n_env, n_steps, n_batch, reuse=Fals cnn_extractor=nature_cnn, feature_extraction="cnn", layer_norm=False, **kwargs): super(FeedForwardPolicy, self).__init__(sess, ob_space, ac_space, n_env, n_steps, n_batch, reuse=reuse, scale=(feature_extraction == "cnn")) + + self._kwargs_check(feature_extraction, kwargs) self.layer_norm = layer_norm self.feature_extraction = feature_extraction self.cnn_kwargs = kwargs diff --git a/stable_baselines/deepq/policies.py b/stable_baselines/deepq/policies.py index cb759c9ffb..7d9709115b 100644 --- a/stable_baselines/deepq/policies.py +++ b/stable_baselines/deepq/policies.py @@ -93,6 +93,9 @@ def __init__(self, sess, ob_space, ac_space, n_env, n_steps, n_batch, reuse=Fals super(FeedForwardPolicy, self).__init__(sess, ob_space, ac_space, n_env, n_steps, n_batch, dueling=dueling, reuse=reuse, scale=(feature_extraction == "cnn"), obs_phs=obs_phs) + + self._kwargs_check(feature_extraction, kwargs) + if layers is None: layers = [64, 64] diff --git a/stable_baselines/sac/policies.py b/stable_baselines/sac/policies.py index a8389a7687..ea645cc7cb 100644 --- a/stable_baselines/sac/policies.py +++ b/stable_baselines/sac/policies.py @@ -183,6 +183,8 @@ def __init__(self, sess, ob_space, ac_space, n_env=1, n_steps=1, n_batch=None, r layer_norm=False, **kwargs): super(FeedForwardPolicy, self).__init__(sess, ob_space, ac_space, n_env, n_steps, n_batch, reuse=reuse, scale=(feature_extraction == "cnn")) + + self._kwargs_check(feature_extraction, kwargs) self.layer_norm = layer_norm self.feature_extraction = 
feature_extraction self.cnn_kwargs = kwargs diff --git a/stable_baselines/trpo_mpi/trpo_mpi.py b/stable_baselines/trpo_mpi/trpo_mpi.py index 4cc1e9cdf7..c88576e432 100644 --- a/stable_baselines/trpo_mpi/trpo_mpi.py +++ b/stable_baselines/trpo_mpi/trpo_mpi.py @@ -119,12 +119,12 @@ def setup_model(self): # Construct network for new policy self.policy_pi = self.policy(self.sess, self.observation_space, self.action_space, self.n_envs, 1, - None, reuse=False) + None, reuse=False, **self.policy_kwargs) # Network for old policy with tf.variable_scope("oldpi", reuse=False): old_policy = self.policy(self.sess, self.observation_space, self.action_space, self.n_envs, 1, - None, reuse=False) + None, reuse=False, **self.policy_kwargs) with tf.variable_scope("loss", reuse=False): atarg = tf.placeholder(dtype=tf.float32, shape=[None]) # Target advantage function (if applicable) diff --git a/tests/test_custom_policy.py b/tests/test_custom_policy.py index 65c6b933d0..cbf1493761 100644 --- a/tests/test_custom_policy.py +++ b/tests/test_custom_policy.py @@ -110,6 +110,10 @@ def test_custom_policy_kwargs(model_name): model_class, policy, policy_kwargs = MODEL_DICT[model_name] env = 'MountainCarContinuous-v0' if model_name in ['ddpg', 'sac'] else 'CartPole-v1' + # Should raise an error when a wrong keyword is passed + with pytest.raises(ValueError): + model_class(policy, env, policy_kwargs=dict(this_throws_error='maybe')) + # create and train model = model_class(policy, env, policy_kwargs=policy_kwargs) model.learn(total_timesteps=100, seed=0) From 4acf953f357ed62986301440b01351864da385a6 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Tue, 15 Jan 2019 22:41:50 +0100 Subject: [PATCH 3/6] Fix DQN examples + fix for codacy --- docs/misc/changelog.rst | 1 + stable_baselines/common/policies.py | 5 +++-- .../deepq/experiments/custom_cartpole.py | 2 +- .../deepq/experiments/train_mountaincar.py | 14 +++----------- 4 files changed, 8 insertions(+), 14 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index f56edae903..72f02db2fb 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -20,6 +20,7 @@ Pre-Release 2.4.0a (WIP) - added auto entropy coefficient optimization for SAC - clip continuous actions at test time too for all algorithms (except SAC/DDPG where it is not needed) - added a mean to pass kwargs to policy when creating a model (+ save those kwargs) +- fixed DQN examples in DQN folder Release 2.3.0 (2018-12-05) diff --git a/stable_baselines/common/policies.py b/stable_baselines/common/policies.py index 57f92c77a7..8d314b571e 100644 --- a/stable_baselines/common/policies.py +++ b/stable_baselines/common/policies.py @@ -123,7 +123,8 @@ def __init__(self, sess, ob_space, ac_space, n_env, n_steps, n_batch, reuse=Fals self.ob_space = ob_space self.ac_space = ac_space - def _kwargs_check(self, feature_extraction, kwargs): + @staticmethod + def _kwargs_check(feature_extraction, kwargs): """ Ensure that the user is not passing wrong keywords when using policy_kwargs @@ -350,7 +351,7 @@ def __init__(self, sess, ob_space, ac_space, n_env, n_steps, n_batch, n_lstm=256 for idx, vf_layer_size in enumerate(value_only_layers): if vf_layer_size == "lstm": raise NotImplementedError("LSTMs are only supported in the shared part of the value function " - "network.") + "network.") assert isinstance(vf_layer_size, int), "Error: net_arch[-1]['vf'] must only contain integers." 
latent_value = act_fun( linear(latent_value, "vf_fc{}".format(idx), vf_layer_size, init_scale=np.sqrt(2))) diff --git a/stable_baselines/deepq/experiments/custom_cartpole.py b/stable_baselines/deepq/experiments/custom_cartpole.py index b039ec5923..3ff006a228 100644 --- a/stable_baselines/deepq/experiments/custom_cartpole.py +++ b/stable_baselines/deepq/experiments/custom_cartpole.py @@ -15,7 +15,7 @@ class CustomPolicy(FeedForwardPolicy): def __init__(self, *args, **kwargs): super(CustomPolicy, self).__init__(*args, **kwargs, - net_arch=[dict(vf=[64], pi=[64])], + layers=[64], feature_extraction="mlp") diff --git a/stable_baselines/deepq/experiments/train_mountaincar.py b/stable_baselines/deepq/experiments/train_mountaincar.py index 16abef9651..d605d1ad76 100644 --- a/stable_baselines/deepq/experiments/train_mountaincar.py +++ b/stable_baselines/deepq/experiments/train_mountaincar.py @@ -3,15 +3,6 @@ import gym from stable_baselines.deepq import DQN -from stable_baselines.deepq.policies import FeedForwardPolicy - - -class CustomPolicy(FeedForwardPolicy): - def __init__(self, *args, **kwargs): - super(CustomPolicy, self).__init__(*args, **kwargs, - net_arch=[dict(pi=[64], vf=[64])], - layer_norm=True, - feature_extraction="mlp") def main(args): @@ -24,13 +15,14 @@ def main(args): # using layer norm policy here is important for parameter space noise! model = DQN( - policy=CustomPolicy, + policy="LnMlpPolicy", env=env, learning_rate=1e-3, buffer_size=50000, exploration_fraction=0.1, exploration_final_eps=0.1, - param_noise=True + param_noise=True, + policy_kwargs=dict(layers=[64]) ) model.learn(total_timesteps=args.max_timesteps) From d62659ad8c8a0a0b3943544800797dc46d4f48e7 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Wed, 16 Jan 2019 11:44:17 +0100 Subject: [PATCH 4/6] [ci skip] Add comments --- stable_baselines/common/policies.py | 9 ++++++++- tests/test_custom_policy.py | 2 +- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/stable_baselines/common/policies.py b/stable_baselines/common/policies.py index 57f92c77a7..a607129b37 100644 --- a/stable_baselines/common/policies.py +++ b/stable_baselines/common/policies.py @@ -126,10 +126,17 @@ def __init__(self, sess, ob_space, ac_space, n_env, n_steps, n_batch, reuse=Fals def _kwargs_check(self, feature_extraction, kwargs): """ Ensure that the user is not passing wrong keywords - when using policy_kwargs + when using policy_kwargs. 
+
         :param feature_extraction: (str)
         :param kwargs: (dict)
         """
+        # When using the policy_kwargs parameter on model creation,
+        # all keyword arguments must be consumed by the policy constructor except
+        # the ones for the cnn_extractor network (cf nature_cnn()), where the keyword arguments
+        # are not passed explicitly (using **kwargs to forward the arguments),
+        # that's why there should be no kwargs left when using the mlp_extractor
+        # (in that case the keyword arguments are passed explicitly)
         if feature_extraction == 'mlp' and len(kwargs) > 0:
             raise ValueError("Unknown keywords for policy: {}".format(kwargs))
 
diff --git a/tests/test_custom_policy.py b/tests/test_custom_policy.py
index cbf1493761..e292a1f126 100644
--- a/tests/test_custom_policy.py
+++ b/tests/test_custom_policy.py
@@ -135,7 +135,7 @@ def test_custom_policy_kwargs(model_name):
         model.learn(total_timesteps=100, seed=0)
         del model
 
-        # Load wit different wrong policy_kwargs
+        # Load with different wrong policy_kwargs
         with pytest.raises(ValueError):
             _ = model_class.load("./test_model", policy=policy, env=env, policy_kwargs=dict(wrong="kwargs"))
 
     finally:
         if os.path.exists("./test_model"):

From 3fec9eb885c31fadd2189d0648f380578c4cfbf5 Mon Sep 17 00:00:00 2001
From: Antonin Raffin
Date: Thu, 17 Jan 2019 14:46:31 +0100
Subject: [PATCH 5/6] Add a note about custom class vs policy kwargs

---
 docs/guide/custom_policy.rst | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)

diff --git a/docs/guide/custom_policy.rst b/docs/guide/custom_policy.rst
index fee6c5be69..16585f92e8 100644
--- a/docs/guide/custom_policy.rst
+++ b/docs/guide/custom_policy.rst
@@ -6,7 +6,8 @@ Custom Policy Network
 Stable baselines provides default policy networks (see :ref:`Policies ` ) for images (CNNPolicies)
 and other type of input features (MlpPolicies).
 
-One way of customising the policy network architecture is to pass argument using ``policy_kwargs`` parameter:
+One way of customising the policy network architecture is to pass arguments when creating the model,
+using the ``policy_kwargs`` parameter:
 
 .. code-block:: python
 
@@ -33,6 +34,13 @@ One way of customising the policy network architecture is to pass argument using
 
 You can also easily define a custom architecture for the policy (or value) network:
 
+.. note::
+
+    Defining a custom policy class is equivalent to passing ``policy_kwargs``. However,
+    it lets you name the policy and so usually makes the code clearer. ``policy_kwargs`` should rather
+    be used when doing hyperparameter search.
+
+
 .. 
code-block:: python import gym From 0ebfeb5d9b690987db6aa46fe66838c5aecd245d Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Thu, 17 Jan 2019 15:06:50 +0100 Subject: [PATCH 6/6] The activation function can now be customized for DQN, DDPG and SAC --- docs/misc/changelog.rst | 1 + stable_baselines/common/policies.py | 3 ++- stable_baselines/ddpg/policies.py | 6 ++++-- stable_baselines/deepq/policies.py | 8 ++++---- stable_baselines/sac/policies.py | 6 ++++-- 5 files changed, 15 insertions(+), 9 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 72f02db2fb..4169197863 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -21,6 +21,7 @@ Pre-Release 2.4.0a (WIP) - clip continuous actions at test time too for all algorithms (except SAC/DDPG where it is not needed) - added a mean to pass kwargs to policy when creating a model (+ save those kwargs) - fixed DQN examples in DQN folder +- added possibility to pass activation function for DDPG, DQN and SAC Release 2.3.0 (2018-12-05) diff --git a/stable_baselines/common/policies.py b/stable_baselines/common/policies.py index 5e744ce1c0..288adecd0f 100644 --- a/stable_baselines/common/policies.py +++ b/stable_baselines/common/policies.py @@ -261,6 +261,7 @@ class LstmPolicy(ActorCriticPolicy): :param layers: ([int]) The size of the Neural network before the LSTM layer (if None, default to [64, 64]) :param net_arch: (list) Specification of the actor-critic policy network architecture. Notation similar to the format described in mlp_extractor but with additional support for a 'lstm' entry in the shared network part. + :param act_fun: (tf.func) the activation function to use in the neural network. :param cnn_extractor: (function (TensorFlow Tensor, ``**kwargs``): (TensorFlow Tensor)) the CNN feature extraction :param layer_norm: (bool) Whether or not to use layer normalizing LSTMs :param feature_extraction: (str) The feature extraction type ("cnn" or "mlp") @@ -403,7 +404,7 @@ class FeedForwardPolicy(ActorCriticPolicy): (if None, default to [64, 64]) :param net_arch: (list) Specification of the actor-critic policy network architecture (see mlp_extractor documentation for details). - :param act_fun: the activation function to use in the neural network. + :param act_fun: (tf.func) the activation function to use in the neural network. :param cnn_extractor: (function (TensorFlow Tensor, ``**kwargs``): (TensorFlow Tensor)) the CNN feature extraction :param feature_extraction: (str) The feature extraction type ("cnn" or "mlp") :param kwargs: (dict) Extra keyword arguments for the nature CNN feature extraction diff --git a/stable_baselines/ddpg/policies.py b/stable_baselines/ddpg/policies.py index 062ab01f34..db5c82af94 100644 --- a/stable_baselines/ddpg/policies.py +++ b/stable_baselines/ddpg/policies.py @@ -100,11 +100,13 @@ class FeedForwardPolicy(DDPGPolicy): :param cnn_extractor: (function (TensorFlow Tensor, ``**kwargs``): (TensorFlow Tensor)) the CNN feature extraction :param feature_extraction: (str) The feature extraction type ("cnn" or "mlp") :param layer_norm: (bool) enable layer normalisation + :param act_fun: (tf.func) the activation function to use in the neural network. 
:param kwargs: (dict) Extra keyword arguments for the nature CNN feature extraction """ def __init__(self, sess, ob_space, ac_space, n_env, n_steps, n_batch, reuse=False, layers=None, - cnn_extractor=nature_cnn, feature_extraction="cnn", layer_norm=False, **kwargs): + cnn_extractor=nature_cnn, feature_extraction="cnn", + layer_norm=False, act_fun=tf.nn.relu, **kwargs): super(FeedForwardPolicy, self).__init__(sess, ob_space, ac_space, n_env, n_steps, n_batch, reuse=reuse, scale=(feature_extraction == "cnn")) @@ -121,7 +123,7 @@ def __init__(self, sess, ob_space, ac_space, n_env, n_steps, n_batch, reuse=Fals assert len(layers) >= 1, "Error: must have at least one hidden layer for the policy." - self.activ = tf.nn.relu + self.activ = act_fun def make_actor(self, obs=None, reuse=False, scope="pi"): if obs is None: diff --git a/stable_baselines/deepq/policies.py b/stable_baselines/deepq/policies.py index 7d9709115b..1e71eee5c1 100644 --- a/stable_baselines/deepq/policies.py +++ b/stable_baselines/deepq/policies.py @@ -84,12 +84,13 @@ class FeedForwardPolicy(DQNPolicy): and the processed observation placeholder respectivly :param layer_norm: (bool) enable layer normalisation :param dueling: (bool) if true double the output MLP to compute a baseline for action scores + :param act_fun: (tf.func) the activation function to use in the neural network. :param kwargs: (dict) Extra keyword arguments for the nature CNN feature extraction """ def __init__(self, sess, ob_space, ac_space, n_env, n_steps, n_batch, reuse=False, layers=None, cnn_extractor=nature_cnn, feature_extraction="cnn", - obs_phs=None, layer_norm=False, dueling=True, **kwargs): + obs_phs=None, layer_norm=False, dueling=True, act_fun=tf.nn.relu, **kwargs): super(FeedForwardPolicy, self).__init__(sess, ob_space, ac_space, n_env, n_steps, n_batch, dueling=dueling, reuse=reuse, scale=(feature_extraction == "cnn"), obs_phs=obs_phs) @@ -105,14 +106,13 @@ def __init__(self, sess, ob_space, ac_space, n_env, n_steps, n_batch, reuse=Fals extracted_features = cnn_extractor(self.processed_obs, **kwargs) action_out = extracted_features else: - activ = tf.nn.relu extracted_features = tf.layers.flatten(self.processed_obs) action_out = extracted_features for layer_size in layers: action_out = tf_layers.fully_connected(action_out, num_outputs=layer_size, activation_fn=None) if layer_norm: action_out = tf_layers.layer_norm(action_out, center=True, scale=True) - action_out = activ(action_out) + action_out = act_fun(action_out) action_scores = tf_layers.fully_connected(action_out, num_outputs=self.n_actions, activation_fn=None) @@ -123,7 +123,7 @@ def __init__(self, sess, ob_space, ac_space, n_env, n_steps, n_batch, reuse=Fals state_out = tf_layers.fully_connected(state_out, num_outputs=layer_size, activation_fn=None) if layer_norm: state_out = tf_layers.layer_norm(state_out, center=True, scale=True) - state_out = tf.nn.relu(state_out) + state_out = act_fun(state_out) state_score = tf_layers.fully_connected(state_out, num_outputs=1, activation_fn=None) action_scores_mean = tf.reduce_mean(action_scores, axis=1) action_scores_centered = action_scores - tf.expand_dims(action_scores_mean, axis=1) diff --git a/stable_baselines/sac/policies.py b/stable_baselines/sac/policies.py index ea645cc7cb..53fefd4db9 100644 --- a/stable_baselines/sac/policies.py +++ b/stable_baselines/sac/policies.py @@ -175,12 +175,14 @@ class FeedForwardPolicy(SACPolicy): :param feature_extraction: (str) The feature extraction type ("cnn" or "mlp") :param layer_norm: (bool) enable layer 
normalisation :param reg_weight: (float) Regularization loss weight for the policy parameters + :param reg_weight: (float) Regularization loss weight for the policy parameters + :param act_fun: (tf.func) the activation function to use in the neural network. :param kwargs: (dict) Extra keyword arguments for the nature CNN feature extraction """ def __init__(self, sess, ob_space, ac_space, n_env=1, n_steps=1, n_batch=None, reuse=False, layers=None, cnn_extractor=nature_cnn, feature_extraction="cnn", reg_weight=0.0, - layer_norm=False, **kwargs): + layer_norm=False, act_fun=tf.nn.relu, **kwargs): super(FeedForwardPolicy, self).__init__(sess, ob_space, ac_space, n_env, n_steps, n_batch, reuse=reuse, scale=(feature_extraction == "cnn")) @@ -199,7 +201,7 @@ def __init__(self, sess, ob_space, ac_space, n_env=1, n_steps=1, n_batch=None, r assert len(layers) >= 1, "Error: must have at least one hidden layer for the policy." - self.activ_fn = tf.nn.relu + self.activ_fn = act_fun def make_actor(self, obs=None, reuse=False, scope="pi"): if obs is None:
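
Taken together, the series makes ``policy_kwargs`` a first-class way to configure the default policies: the kwargs are forwarded to the policy constructor, validated by ``_kwargs_check``, and saved/restored with the model, and PATCH 6/6 additionally exposes ``act_fun`` for the DQN, DDPG and SAC feed-forward policies. Below is a minimal usage sketch of that workflow; the environment, hyperparameters and save path are illustrative only and not part of the patches above:

.. code-block:: python

    import tensorflow as tf

    from stable_baselines import DQN

    # Architecture and activation function configured through policy_kwargs
    # (act_fun for DQN/DDPG/SAC requires PATCH 6/6)
    policy_kwargs = dict(act_fun=tf.nn.elu, layers=[64, 64], dueling=False)

    model = DQN("MlpPolicy", "CartPole-v1", policy_kwargs=policy_kwargs, verbose=1)
    model.learn(total_timesteps=10000)
    model.save("dqn-cartpole")
    del model

    # policy_kwargs are stored in the saved model and restored on load
    model = DQN.load("dqn-cartpole")

    # Unknown keywords are rejected by the _kwargs_check added in this series,
    # e.g. the following would raise a ValueError:
    # DQN("MlpPolicy", "CartPole-v1", policy_kwargs=dict(wrong="kwargs"))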