mrahtz/tf-sac

TensorFlow implementation of Soft Actor-Critic


TensorFlow Soft Actor-Critic

TensorFlow implementation of Haarnoja et al.'s Soft Actor-Critic.

Results

Testing on HalfCheetah-v2 and Walker2d-v2:

Compared to Spinning Up's implementation, we achieve about half the performance on HalfCheetah and about the same performance on Walker2d:

I'm not sure why HalfCheetah performs comparatively so much worse. It may be because we train on every environment step instead of in a batch at the end of each episode, but I haven't investigated in detail.
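Here is a minimal sketch of the two update schedules. It is illustrative only. It assumes, without confirmation from the source, that both schedules perform one gradient update per environment step in total. `training_schedule` is a hypothetical helper that just records the order of events.

```python
def training_schedule(episode_len, per_step):
    """Return the order of environment steps ("env") and gradient updates ("update")
    within one episode. Illustrative only: assumes one update per env step overall."""
    events = []
    for _ in range(episode_len):
        events.append("env")
        if per_step:
            # This repo's approach: update immediately after every env step.
            events.append("update")
    if not per_step:
        # Batched alternative: defer all updates until the episode ends.
        events.extend(["update"] * episode_len)
    return events


print(training_schedule(3, per_step=True))
# ['env', 'update', 'env', 'update', 'env', 'update']
print(training_schedule(3, per_step=False))
# ['env', 'env', 'env', 'update', 'update', 'update']
```

Under that assumption both schedules take the same number of gradient steps. What differs is how up to date the policy is while it collects data during the episode.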

Usage

Setup

To set up a virtual environment and install requirements, install Pipenv, then just run:

$ pipenv sync

If you want to run MuJoCo environments, you'll also need to install mujoco-py version 1.50.1.68.

Training

Basic usage is:

$ pipenv run python -m sac.train

We use Sacred for configuration management. See config.py for available parameters, and set parameters using e.g.:

$ pipenv run python -m sac.train with env_id=HalfCheetah-v2 render=True

A run directory will be created in runs/ containing TensorBoard metrics.

To view a trained agent:

$ pipenv run python -m sac.play env_id runs/x/checkpoints/model-yyy.pkl

Tests

To run smoke tests and unit tests, respectively:

$ pipenv run python tests.py Smoke
$ pipenv run python tests.py Unit

Unit tests cover:

  • Target network update
  • Gaussian policy log prob calculation
  • Action limit
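The first and third of these properties can be sketched as plain NumPy checks. This is not the repo's test code. The assumptions are:

- `soft_update` and `scale_action` are hypothetical helpers.
- The smoothing coefficient `tau = 0.005` is an assumed value.
- Mapping the tanh output linearly onto `[low, high]` is just one common way to enforce an action limit.

```python
import numpy as np


def soft_update(target_params, online_params, tau):
    """Polyak-average each online weight array into its target copy:
    target <- tau * online + (1 - tau) * target."""
    return [tau * w + (1.0 - tau) * w_t for w_t, w in zip(target_params, online_params)]


def scale_action(squashed, low, high):
    """Map a tanh-squashed action in [-1, 1] linearly onto the box [low, high]."""
    return low + 0.5 * (squashed + 1.0) * (high - low)


# Target network update: one step moves the target a fraction tau toward the online weights.
tau = 0.005
target, online = [np.zeros(3)], [np.ones(3)]
target = soft_update(target, online, tau)
print(target[0])  # [0.005 0.005 0.005]

# Action limit: even extreme pre-tanh values stay inside the bounds.
actions = scale_action(np.tanh(np.array([-50.0, 0.0, 50.0])), low=-2.0, high=2.0)
print(actions)  # [-2.  0.  2.]
```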

Lessons learned

Stochastic policies

The main thing that surprised me about Soft Actor-Critic was how hard it was to implement the stochastic tanh-Gaussian policy correctly. There are a lot of fiddly details (e.g. do you tanh the mean before or after you sample?), and it's tricky to figure out how the tanh modifies the PDF (see https://math.stackexchange.com/a/3283855/468726).
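In the SAC paper's formulation, you sample u from the Gaussian first and only then squash it, a = tanh(u). The change of variables then subtracts the log-Jacobian of tanh: log π(a|s) = log N(u; μ, σ) − Σᵢ log(1 − tanh²(uᵢ)).

Below is a NumPy sketch of that computation, not this repo's code. It rewrites log(1 − tanh²(u)) as 2(log 2 − u − softplus(−2u)). That form is algebraically equivalent but stays finite even when tanh(u) rounds to ±1. `sample_tanh_gaussian` and `softplus` are hypothetical helper names.

```python
import numpy as np

LOG_2PI = np.log(2.0 * np.pi)


def softplus(x):
    # log(1 + exp(x)) computed without overflow.
    return np.logaddexp(0.0, x)


def sample_tanh_gaussian(mu, log_std, rng):
    """Sample a = tanh(u), u ~ N(mu, diag(exp(log_std))^2); return (a, u, log pi(a))."""
    std = np.exp(log_std)
    u = mu + std * rng.standard_normal(np.shape(mu))  # sample first...
    a = np.tanh(u)                                    # ...then squash

    # Log density of u under the diagonal Gaussian, summed over action dimensions.
    log_prob_u = np.sum(-0.5 * ((u - mu) / std) ** 2 - log_std - 0.5 * LOG_2PI, axis=-1)

    # Change of variables: subtract sum_i log(1 - tanh(u_i)^2), written in a stable form.
    log_det_jac = np.sum(2.0 * (np.log(2.0) - u - softplus(-2.0 * u)), axis=-1)
    return a, u, log_prob_u - log_det_jac


rng = np.random.default_rng(0)
mu = np.zeros((4, 2))       # batch of 4 observations, 2-D actions
log_std = np.zeros((4, 2))
a, u, log_pi = sample_tanh_gaussian(mu, log_std, rng)
print(a.shape, log_pi.shape)  # (4, 2) (4,)
```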

tanh precision

tanh reaches the limit of floating-point precision surprisingly quickly, even in NumPy's default float64:

>>> np.tanh(19)
0.99999999999999989
>>> np.tanh(20)
1.0

Be really careful if you need to tanh something and then later arctanh it, or you'll get numerical errors.
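The snippet below is a sketch, not from the repo. It shows the same effect in float32, the dtype TensorFlow models typically use, and what goes wrong when you later invert a saturated value. Keeping the pre-tanh value around, rather than recomputing it with arctanh, sidesteps the problem. The clipping shown at the end only hides it.

```python
import numpy as np

# NumPy defaults to float64; float32 saturates much earlier.
print(np.tanh(np.float64(25.0)) == 1.0)   # True
print(np.tanh(np.float32(12.0)) == 1.0)   # True

# Once tanh has rounded to exactly 1.0, arctanh cannot recover the input.
a = np.tanh(np.float64(25.0))
with np.errstate(divide="ignore"):
    recovered = np.arctanh(a)
print(recovered)                           # inf

# Clipping avoids the inf, but still returns a value far from the original 25.0.
eps = 1e-6
clipped = np.arctanh(np.clip(a, -1.0 + eps, 1.0 - eps))
print(np.isfinite(clipped), clipped < 25.0)  # True True
```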

Keras

Since tf.layers.dense is now deprecated in favour of Keras layers, this was one of my first projects using Keras in anger.

I love how easy Keras's model paradigm makes it to reuse layers. For example, to use the same set of policy weights to calculate an action for two different observations:

pi = PolicyModel()
action1 = pi(obs1)
action2 = pi(obs2)

PolicyModel() creates a model, pi, that holds a single set of weights (for a subclassed model, Keras creates them on the first call). Calling pi on other tensors then applies the same transformation with those same weights.
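Here is a minimal runnable version of the snippet above. It is written against the TF 2 `tf.keras` API, whereas this README's other code uses TF 1 graph mode. `PolicyModel` is a hypothetical two-layer network. The point is that both calls share the same four weight tensors.

```python
import tensorflow as tf


class PolicyModel(tf.keras.Model):
    """Hypothetical policy: obs -> Dense(16) -> Dense(1) -> tanh."""

    def __init__(self):
        super().__init__()
        self.h = tf.keras.layers.Dense(16, activation="relu")
        self.act = tf.keras.layers.Dense(1)

    def call(self, inputs):
        return tf.tanh(self.act(self.h(inputs)))


pi = PolicyModel()
obs1 = tf.random.normal((2, 3))  # 2 observations with 3 features each
obs2 = tf.random.normal((5, 3))  # 5 more observations, same feature size
action1 = pi(obs1)  # first call builds the weights
action2 = pi(obs2)  # second call reuses them
print(len(pi.trainable_variables))  # 4: one kernel and one bias per Dense layer
```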

I preferred subclassing Model to calling Model on a bunch of Keras layers: you don't have to worry about input shapes or muck around with Lambda layers as much. For example:

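import tensorflow as tf  # TF 1.x API: TF 2 moved tf.placeholder to tf.compat.v1
from tensorflow.keras import Model
from tensorflow.keras.layers import Dense, Input, Lambda

obs_dim = 4  # example observation size

def get_pi_model(obs_dim):
    obs = Input(shape=(obs_dim,))
    h = Dense(16)(obs)
    act = Dense(1)(h)
    tanh_act = Lambda(lambda x: tf.tanh(x))(act)
    return Model(inputs=obs, outputs=tanh_act)

obs = tf.placeholder(tf.float32, [None, obs_dim])
pi = get_pi_model(obs_dim)
pi(obs)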

vs.

class Pi(Model):
    def __init__(self):
        super().__init__()
        self.h = Dense(16)
        self.act = Dense(1)

    def call(self, inputs, **kwargs):
        x = self.h(inputs)
        x = self.act(x)
        x = tf.tanh(x)
        return x

obs = tf.placeholder(tf.float32, [None, obs_dim])
pi = Pi()
pi(obs)

Graph mistakes

Take a look at this code fragment.

q_backup = rews + discount * (1 - done) * v_targ_obs2
q_loss = (q - q_backup) ** 2

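Does it give you the heebie-jeebies? Well, it should!

Why? Don't forget to `tf.stop_gradient` your Bellman backups! Otherwise, minimizing `q_loss` also pushes gradients through `v_targ_obs2` into the target network, instead of treating the backup as a fixed regression target.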
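The effect is easy to check with a toy example. This sketch uses TF 2's `tf.GradientTape` rather than TF 1 graph mode. Single scalar variables stand in for the Q and target-V networks, and `q_param` and `v_targ_param` are hypothetical names. Without `tf.stop_gradient`, the loss has a gradient with respect to the target parameter. With it, the target parameter receives no gradient at all.

```python
import tensorflow as tf

q_param = tf.Variable(1.0)        # stand-in for the Q-network's weights
v_targ_param = tf.Variable(2.0)   # stand-in for the target V-network's weights
obs, obs2 = tf.constant(3.0), tf.constant(3.0)
rews, done, discount = tf.constant(1.0), tf.constant(0.0), 0.99


def q_loss_grads(stop_grad):
    with tf.GradientTape() as tape:
        q = q_param * obs
        v_targ_obs2 = v_targ_param * obs2
        q_backup = rews + discount * (1 - done) * v_targ_obs2
        if stop_grad:
            q_backup = tf.stop_gradient(q_backup)  # treat the backup as a fixed target
        q_loss = (q - q_backup) ** 2
    return tape.gradient(q_loss, [q_param, v_targ_param])


print(q_loss_grads(stop_grad=False))  # both gradients are tensors
print(q_loss_grads(stop_grad=True))   # the v_targ_param gradient is None
```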

What about this one?

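q_loss = (q_obs1 - q_backup) ** 2
train_op = tf.train.AdamOptimizer().minimize(q_loss)

What's wrong? Don't forget to `tf.reduce_mean` your losses! Here `q_loss` is a vector with one entry per sample, and `minimize()` silently differentiates the sum of its entries, so the size of the gradient grows with the batch size.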
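Here is a toy check, using TF 2's `tf.GradientTape` and made-up data. With a vector-valued loss you get the gradient of the summed loss. That is batch-size times larger than the gradient of the mean.

```python
import tensorflow as tf

w = tf.Variable(0.5)
x = tf.constant([1.0, 2.0, 3.0, 4.0])  # toy batch of 4 inputs
y = tf.constant([0.0, 1.0, 2.0, 3.0])  # toy regression targets

with tf.GradientTape(persistent=True) as tape:
    per_sample_loss = (w * x - y) ** 2           # shape [4], like q_loss above
    mean_loss = tf.reduce_mean(per_sample_loss)  # scalar

grad_vector_loss = tape.gradient(per_sample_loss, w)  # gradient of the *sum*
grad_mean_loss = tape.gradient(mean_loss, w)
del tape

print(float(grad_vector_loss), float(grad_mean_loss))  # -10.0 -2.5
```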

Finally, say you're implementing a DDPG-style graph, where the policy loss is based on the Q output. What about this one?

pi_train = tf.train.AdamOptimizer(learning_rate=lr).minimize(pi_loss)

Why be nervous? Your optimizer will try to modify the Q parameters, too! Don't forget to limit optimizers to only the variables you care about (e.g. with the `var_list` argument to `minimize()`)!
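In TF 1 the fix is to pass `var_list` to `minimize()`. The sketch below shows the same idea with TF 2's `tf.GradientTape` and a hand-written gradient step. Scalar variables act as hypothetical stand-ins for the policy and Q networks. Only the variables in `var_list` move.

```python
import tensorflow as tf

pi_w = tf.Variable(0.5)   # stand-in for the policy network's weights
q_w = tf.Variable(2.0)    # stand-in for the Q-network's weights
lr = 0.1
obs = tf.constant(1.0)


def train_pi_step(var_list):
    """One gradient-descent step on a DDPG-style policy loss, updating only var_list."""
    with tf.GradientTape() as tape:
        action = tf.tanh(pi_w * obs)  # toy deterministic policy
        q_value = q_w * action        # toy Q(s, pi(s))
        pi_loss = -q_value            # maximize Q by minimizing -Q
    grads = tape.gradient(pi_loss, var_list)
    for grad, var in zip(grads, var_list):
        var.assign_sub(lr * grad)


q_before = q_w.numpy()
train_pi_step([pi_w])             # correct: only the policy variable moves
print(q_w.numpy() == q_before)    # True
train_pi_step([pi_w, q_w])        # the mistake: the policy loss also "trains" Q
print(q_w.numpy() == q_before)    # False
```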

PyCharm debugging templates

This implementation required a lot of inspection of intermediate values in the graph, but tf.print needs so much boilerplate:

with tf.control_dependencies([tf.print(x, summarize=999)]):
    x = tf.identity(x)

Amazingly, it turns out PyCharm has a templating system (Live Templates) that supports substitutions, including the currently selected text!

With an appropriate template

with tf.control_dependencies([tf.print('$SELECTION$', $SELECTION$, summarize=999)]):
    $SELECTION$ = tf.identity($SELECTION$)

adding a print op takes only a few seconds.
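If you'd rather not depend on an editor feature, you can wrap the same tf.print boilerplate in a small helper. This is a sketch, and `print_and_identity` is a hypothetical name. In TF 1 graph mode, the control dependency makes the print run whenever the returned tensor is evaluated. In TF 2 eager mode, `tf.print` simply runs immediately.

```python
import tensorflow as tf


def print_and_identity(x, label, summarize=999):
    """Return a tensor equal to x that also prints label and x when evaluated."""
    with tf.control_dependencies([tf.print(label, x, summarize=summarize)]):
        return tf.identity(x)


x = tf.constant([1.0, 2.0, 3.0])
y = print_and_identity(x, "x:")  # prints: x: [1 2 3]
```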
