Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Soft Actor Critic (SAC) Model #627

Merged
merged 43 commits into from
Sep 8, 2021
Merged
Show file tree
Hide file tree
Changes from 40 commits
Commits
Show all changes
43 commits
Select commit Hold shift + click to select a range
8f1bf23
finish soft actor critic
blahBlahhhJ Apr 28, 2021
8c2145f
added tests
blahBlahhhJ Apr 29, 2021
0c872a1
finish document and init
blahBlahhhJ May 1, 2021
742943e
fix style 1
blahBlahhhJ May 1, 2021
700cdbb
fix style 2
blahBlahhhJ May 1, 2021
08ce087
fix style 3
blahBlahhhJ May 7, 2021
26ccf1c
Merge branch 'master' into feature/596-sac
Borda Jun 24, 2021
a544901
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 24, 2021
71e0dec
formt
Borda Jun 24, 2021
d4abe63
Merge branch 'feature/596-sac' of https://github.com/blahBlahhhJ/ligh…
Borda Jun 24, 2021
557ea57
Apply suggestions from code review
Borda Jun 24, 2021
ad47e34
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 24, 2021
8c44f3e
Merge branch 'master' into feature/596-sac
mergify[bot] Jun 25, 2021
c26a88b
Merge branch 'master' into feature/596-sac
Borda Jul 4, 2021
d0e60d3
Merge branch 'master' into feature/596-sac
mergify[bot] Jul 4, 2021
3254dbd
Merge branch 'master' into feature/596-sac
mergify[bot] Jul 4, 2021
d81e8e0
use hyperparameters in hparams
blahBlahhhJ Jul 7, 2021
1a8e73f
Merge branch 'feature/596-sac' of https://github.com/blahBlahhhJ/ligh…
blahBlahhhJ Jul 7, 2021
d101d50
Add CHANGELOG
blahBlahhhJ Jul 7, 2021
c52ea1a
Merge branch 'master' into feature/596-sac
mergify[bot] Jul 7, 2021
48800c9
Merge branch 'master' into feature/596-sac
mergify[bot] Jul 13, 2021
47bb401
Merge branch 'master' into feature/596-sac
mergify[bot] Jul 13, 2021
43daba3
fix test
blahBlahhhJ Jul 20, 2021
bfc7028
Merge branch 'master' into feature/596-sac
mergify[bot] Jul 26, 2021
fd0964b
Merge branch 'master' into feature/596-sac
mergify[bot] Jul 28, 2021
2576333
fix format
blahBlahhhJ Aug 1, 2021
a1ec703
Merge branch 'feature/596-sac' of https://github.com/blahBlahhhJ/ligh…
blahBlahhhJ Aug 1, 2021
4723212
Merge branch 'master' into feature/596-sac
mergify[bot] Aug 9, 2021
05b1084
Merge branch 'master' into feature/596-sac
mergify[bot] Aug 13, 2021
c1660af
Merge branch 'master' into feature/596-sac
mergify[bot] Aug 13, 2021
b207d3c
Merge branch 'master' into feature/596-sac
blahBlahhhJ Aug 13, 2021
73a13d1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 13, 2021
be19c64
fix __init__
blahBlahhhJ Aug 13, 2021
25aa7e0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 13, 2021
c6104c0
Merge branch 'master' into feature/596-sac
mergify[bot] Aug 19, 2021
4486569
Merge branch 'master' into feature/596-sac
mergify[bot] Aug 27, 2021
427d5ab
Merge branch 'master' into feature/596-sac
mergify[bot] Aug 27, 2021
cbcc5c0
Merge branch 'master' into feature/596-sac
mergify[bot] Aug 29, 2021
41d7365
Merge branch 'master' into feature/596-sac
mergify[bot] Aug 29, 2021
cccd10d
Merge branch 'master' into feature/596-sac
Sep 7, 2021
bfbae6b
Fix tests
Sep 8, 2021
c0d16fd
Fix reference
Sep 8, 2021
7a0e944
Fix duplication
Sep 8, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,17 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

- Added `EMNISTDataModule`, `BinaryEMNISTDataModule`, and `BinaryEMNIST` dataset ([#676](https://github.com/PyTorchLightning/lightning-bolts/pull/676))

- Added Soft Actor Critic (SAC) Model [#627](https://github.com/PyTorchLightning/lightning-bolts/pull/627))

- Added Advantage Actor-Critic (A2C) Model [#598](https://github.com/PyTorchLightning/lightning-bolts/pull/598))
- Added `EMNISTDataModule`, `BinaryEMNISTDataModule`, and `BinaryEMNIST` dataset ([#676](https://github.com/PyTorchLightning/lightning-bolts/pull/676))

- Added Advantage Actor-Critic (A2C) Model [#598](https://github.com/PyTorchLightning/lightning-bolts/pull/598))

- Added Torch ORT Callback [#720](https://github.com/PyTorchLightning/lightning-bolts/pull/720))


- Added SparseML Callback [#724](https://github.com/PyTorchLightning/lightning-bolts/pull/724))


### Changed

- Changed the default values `pin_memory=False`, `shuffle=False` and `num_workers=16` to `pin_memory=True`, `shuffle=True` and `num_workers=0` of datamodules ([#701](https://github.com/PyTorchLightning/lightning-bolts/pull/701))
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
86 changes: 86 additions & 0 deletions docs/source/reinforce_learn.rst
Original file line number Diff line number Diff line change
Expand Up @@ -764,3 +764,89 @@ Example::

.. autoclass:: pl_bolts.models.rl.AdvantageActorCritic
:noindex:

--------------

Actor-Critic Models
-------------------
The following models are based on Actor Critic. Actor Critic conbines the approaches of value-based learning (the DQN family)
and the policy-based learning (the PG family) by learning the value function as well as the policy distribution. This approach
updates the policy network according to the policy gradient, and updates the value network to fit the discounted rewards.

Actor Critic Key Points:
- Actor outputs a distribution of actions for controlling the agent
- Critic outputs a value of current state for policy update suggestion
- The addition of critic allows the model to do n-step training instead of generating an entire trajectory

Soft Actor Critic (SAC)
^^^^^^^^^^^^^^^^^^^^^^^

Soft Actor Critic model introduced in `Soft Actor-Critic: Off-Policy Maximum Entropy Deep Reinforcement Learning with a Stochastic Actor<https://arxiv.org/abs/1801.01290>`_
Paper authors: Tuomas Haarnoja, Aurick Zhou, Pieter Abbeel, Sergey Levine

Original implementation by: `Jason Wang <https://github.com/blahBlahhhJ>`_

Soft Actor Critic (SAC) is a powerful actor critic algorithm in reinforcement learning. Unlike A2C, SAC's policy outputs a
special continuous distribution for actions, and its critic estimates the Q value instead of the state value, which
means it now takes in not only states but also actions. The new actor allows SAC to support continuous action tasks such
as controlling robots, and the new critic allows SAC to support off-policy learning which is more sample efficient.

The actor has a new objective to maximize entropy to encourage exploration while maximizing the expected rewards.
The critic uses two separate Q functions to "mitigate positive bias" during training by picking the minimum of the
two as the predicted Q value.

Since SAC is off-policy, its algorithm's training step is quite similar to DQN:

1. Initialize one policy network, two Q networks, and two corresponding target Q networks.
2. Run 1 step using action sampled from policy and store the transition into the replay buffer.

.. math::
a \sim tanh(N(\mu_\pi(s), \sigma_\pi(s)))

3. Sample transitions (states, actions, rewards, dones, next states) from the replay buffer.

.. math::
s, a, r, d, s' \sim B

4. Compute actor loss and update policy network.

.. math::
J_\pi = \frac1n\sum_i(\log\pi(\pi(a | s_i) | s_i) - Q_{min}(s_i, \pi(a | s_i)))

5. Compute Q target

.. math::
target_i = r_i + (1 - d_i) \gamma (\min_i Q_{target,i}(s'_i, \pi(a', s'_i)) - log\pi(\pi(a | s'_i) | s'_i))

5. Compute critic loss and update Q network..

.. math::
J_{Q_i} = \frac1n \sum_i(Q_i(s_i, a_i) - target_i)^2

4. Soft update the target Q network using a weighted sum of itself and the Q network.

.. math::
Q_{target,i} := \tau Q_{target,i} + (1-\tau) Q_i

SAC Benefits
~~~~~~~~~~~~~~~~~~~

- More sample efficient due to off-policy training

- Supports continuous action space

SAC Results
~~~~~~~~~~~~~~~~

.. image:: _images/rl_benchmark/pendulum_sac_results.jpg
:width: 300
:alt: SAC Results

Example::
from pl_bolts.models.rl import SAC
sac = SAC("Pendulum-v0")
trainer = Trainer()
trainer.fit(sac)

.. autoclass:: pl_bolts.models.rl.SAC
:noindex:
2 changes: 2 additions & 0 deletions pl_bolts/models/rl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from pl_bolts.models.rl.noisy_dqn_model import NoisyDQN
from pl_bolts.models.rl.per_dqn_model import PERDQN
from pl_bolts.models.rl.reinforce_model import Reinforce
from pl_bolts.models.rl.sac_model import SAC
from pl_bolts.models.rl.vanilla_policy_gradient_model import VanillaPolicyGradient

__all__ = [
Expand All @@ -15,5 +16,6 @@
"NoisyDQN",
"PERDQN",
"Reinforce",
"SAC",
"VanillaPolicyGradient",
]
45 changes: 45 additions & 0 deletions pl_bolts/models/rl/common/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,3 +161,48 @@ def __call__(self, states: Tensor, device: str) -> List[int]:
actions = [np.random.choice(len(prob), p=prob) for prob in prob_np]

return actions


class SoftActorCriticAgent(Agent):
"""Actor-Critic based agent that returns a continuous action based on the policy."""

def __call__(self, states: Tensor, device: str) -> List[float]:
"""Takes in the current state and returns the action based on the agents policy.

Args:
states: current state of the environment
device: the device used for the current batch

Returns:
action defined by policy
"""
if not isinstance(states, list):
states = [states]

if not isinstance(states, Tensor):
states = torch.tensor(states, device=device)

dist = self.net(states)
actions = [a for a in dist.sample().cpu().numpy()]

return actions

def get_action(self, states: Tensor, device: str) -> List[float]:
"""Get the action greedily (without sampling)

Args:
states: current state of the environment
device: the device used for the current batch

Returns:
action defined by policy
"""
if not isinstance(states, list):
states = [states]

if not isinstance(states, Tensor):
states = torch.tensor(states, device=device)

actions = [self.net.get_action(states).cpu().numpy()]

return actions
62 changes: 62 additions & 0 deletions pl_bolts/models/rl/common/distributions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
"""Distributions used in some continuous RL algorithms."""
import torch


class TanhMultivariateNormal(torch.distributions.MultivariateNormal):
"""The distribution of X is an affine of tanh applied on a normal distribution.

X = action_scale * tanh(Z) + action_bias
Z ~ Normal(mean, variance)
"""

def __init__(self, action_bias, action_scale, **kwargs):
super().__init__(**kwargs)

self.action_bias = action_bias
self.action_scale = action_scale

def rsample_with_z(self, sample_shape=torch.Size()):
"""Samples X using reparametrization trick with the intermediate variable Z.

Returns:
Sampled X and Z
"""
z = super().rsample()
return self.action_scale * torch.tanh(z) + self.action_bias, z

def log_prob_with_z(self, value, z):
"""Computes the log probability of a sampled X.

Refer to the original paper of SAC for more details in equation (20), (21)

Args:
value: the value of X
z: the value of Z
Returns:
Log probability of the sample
"""
value = (value - self.action_bias) / self.action_scale
z_logprob = super().log_prob(z)
correction = torch.log(self.action_scale * (1 - value ** 2) + 1e-7).sum(1)
return z_logprob - correction

def rsample_and_log_prob(self, sample_shape=torch.Size()):
"""Samples X and computes the log probability of the sample.

Returns:
Sampled X and log probability
"""
z = super().rsample()
z_logprob = super().log_prob(z)
value = torch.tanh(z)
correction = torch.log(self.action_scale * (1 - value ** 2) + 1e-7).sum(1)
return self.action_scale * value + self.action_bias, z_logprob - correction

def rsample(self, sample_shape=torch.Size()):
fz, z = self.rsample_with_z(sample_shape)
return fz

def log_prob(self, value):
value = (value - self.action_bias) / self.action_scale
z = torch.log(1 + value) / 2 - torch.log(1 - value) / 2
return self.log_prob_with_z(value, z)
62 changes: 61 additions & 1 deletion pl_bolts/models/rl/common/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@

import numpy as np
import torch
from torch import Tensor, nn
from torch import FloatTensor, Tensor, nn
from torch.distributions import Categorical, Normal
from torch.nn import functional as F

from pl_bolts.models.rl.common.distributions import TanhMultivariateNormal


class CNN(nn.Module):
"""Simple MLP network."""
Expand Down Expand Up @@ -84,6 +86,64 @@ def forward(self, input_x):
return self.net(input_x.float())


class ContinuousMLP(nn.Module):
"""MLP network that outputs continuous value via Gaussian distribution."""

def __init__(
self,
input_shape: Tuple[int],
n_actions: int,
hidden_size: int = 128,
action_bias: int = 0,
action_scale: int = 1,
):
"""
Args:
input_shape: observation shape of the environment
n_actions: dimension of actions in the environment
hidden_size: size of hidden layers
action_bias: the center of the action space
action_scale: the scale of the action space
"""
super().__init__()
self.action_bias = action_bias
self.action_scale = action_scale

self.shared_net = nn.Sequential(
nn.Linear(input_shape[0], hidden_size), nn.ReLU(), nn.Linear(hidden_size, hidden_size), nn.ReLU()
)
self.mean_layer = nn.Linear(hidden_size, n_actions)
self.logstd_layer = nn.Linear(hidden_size, n_actions)

def forward(self, x: FloatTensor) -> TanhMultivariateNormal:
"""Forward pass through network. Calculates the action distribution.

Args:
x: input to network
Returns:
action distribution
"""
x = self.shared_net(x.float())
batch_mean = self.mean_layer(x)
logstd = torch.clamp(self.logstd_layer(x), -20, 2)
batch_scale_tril = torch.diag_embed(torch.exp(logstd))
return TanhMultivariateNormal(
action_bias=self.action_bias, action_scale=self.action_scale, loc=batch_mean, scale_tril=batch_scale_tril
)

def get_action(self, x: FloatTensor) -> Tensor:
"""Get the action greedily (without sampling)

Args:
x: input to network
Returns:
mean action
"""
x = self.shared_net(x.float())
batch_mean = self.mean_layer(x)
return self.action_scale * torch.tanh(batch_mean) + self.action_bias


class ActorCriticMLP(nn.Module):
"""MLP network with heads for actor and critic."""

Expand Down
Loading