Enhance ddpg (#158)
Curt-Park authored and MrSyee committed Apr 17, 2019
1 parent 43e8a92 commit 4248057
Showing 6 changed files with 34 additions and 13 deletions.
13 changes: 9 additions & 4 deletions algorithms/ddpg/agent.py
@@ -13,6 +13,7 @@
 import gym
 import numpy as np
 import torch
+import torch.nn as nn
 import torch.nn.functional as F
 import wandb

@@ -101,13 +102,13 @@ def select_action(self, state: np.ndarray) -> np.ndarray:
         ):
             return self.env.action_space.sample()
 
-        selected_action = self.actor(state)
+        selected_action = self.actor(state).detach().cpu().numpy()
 
         if not self.args.test:
-            selected_action += torch.FloatTensor(self.noise.sample()).to(device)
-            selected_action = torch.clamp(selected_action, -1.0, 1.0)
+            noise = self.noise.sample()
+            selected_action = np.clip(selected_action + noise, -1.0, 1.0)
 
-        return selected_action.detach().cpu().numpy()
+        return selected_action
 
     # pylint: disable=no-self-use
     def _preprocess_state(self, state: np.ndarray) -> torch.Tensor:
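
Note: for readers skimming the diff, here is a minimal standalone sketch of the new action-selection flow: run the deterministic actor, move the output to NumPy, add exploration noise, and clip to the action bounds. The actor, noise, device, and is_test names are stand-ins for the agent's attributes, not the repository's exact code.

import numpy as np
import torch

def select_action_sketch(actor, noise, state, device, is_test=False):
    """Hypothetical helper mirroring the patched select_action logic."""
    state_t = torch.FloatTensor(state).to(device)
    # deterministic policy output, detached and converted to NumPy
    selected_action = actor(state_t).detach().cpu().numpy()
    if not is_test:
        # exploration noise is added in NumPy space and the result is clipped
        selected_action = np.clip(selected_action + noise.sample(), -1.0, 1.0)
    return selected_action
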
@@ -150,17 +151,21 @@ def update_model(self) -> Tuple[torch.Tensor, torch.Tensor]:
         curr_returns = curr_returns.to(device)
 
         # train critic
+        gradient_clip_cr = self.hyper_params["GRADIENT_CLIP_CR"]
         values = self.critic(torch.cat((states, actions), dim=-1))
         critic_loss = F.mse_loss(values, curr_returns)
         self.critic_optimizer.zero_grad()
         critic_loss.backward()
+        nn.utils.clip_grad_norm_(self.critic.parameters(), gradient_clip_cr)
         self.critic_optimizer.step()
 
         # train actor
+        gradient_clip_ac = self.hyper_params["GRADIENT_CLIP_AC"]
         actions = self.actor(states)
         actor_loss = -self.critic(torch.cat((states, actions), dim=-1)).mean()
         self.actor_optimizer.zero_grad()
         actor_loss.backward()
+        nn.utils.clip_grad_norm_(self.actor.parameters(), gradient_clip_ac)
         self.actor_optimizer.step()
 
         # update target networks
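
Note: the clipping calls added above follow the standard pattern of rescaling gradients between backward() and the optimizer step. A self-contained sketch of that ordering, with a placeholder network, loss, and clip value rather than the repository's actor/critic:

import torch
import torch.nn as nn

net = nn.Linear(4, 1)                          # placeholder network
optimizer = torch.optim.Adam(net.parameters(), lr=3e-4)
loss = net(torch.randn(8, 4)).pow(2).mean()    # placeholder loss

optimizer.zero_grad()
loss.backward()
# rescale gradients in place so their global L2 norm is at most max_norm
nn.utils.clip_grad_norm_(net.parameters(), max_norm=1.0)
optimizer.step()
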
5 changes: 5 additions & 0 deletions algorithms/fd/ddpg_agent.py
@@ -13,6 +13,7 @@
 
 import numpy as np
 import torch
+import torch.nn as nn
 
 from algorithms.common.buffer.priortized_replay_buffer import PrioritizedReplayBufferfD
 from algorithms.common.buffer.replay_buffer import NStepTransitionBuffer
@@ -104,6 +105,7 @@ def update_model(self) -> Tuple[torch.Tensor, torch.Tensor]:
         gamma = self.hyper_params["GAMMA"]
 
         # train critic
+        gradient_clip_cr = self.hyper_params["GRADIENT_CLIP_CR"]
         critic_loss_element_wise = self._get_critic_loss(experiences_1, gamma)
         critic_loss = torch.mean(critic_loss_element_wise * weights)
@@ -118,14 +120,17 @@
 
         self.critic_optimizer.zero_grad()
         critic_loss.backward()
+        nn.utils.clip_grad_norm_(self.critic.parameters(), gradient_clip_cr)
         self.critic_optimizer.step()
 
         # train actor
+        gradient_clip_ac = self.hyper_params["GRADIENT_CLIP_AC"]
         actions = self.actor(states)
         actor_loss_element_wise = -self.critic(torch.cat((states, actions), dim=-1))
         actor_loss = torch.mean(actor_loss_element_wise * weights)
         self.actor_optimizer.zero_grad()
         actor_loss.backward()
+        nn.utils.clip_grad_norm_(self.actor.parameters(), gradient_clip_ac)
         self.actor_optimizer.step()
 
         # update target networks
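
Note: the "# update target networks" step that follows is not expanded in this diff; it is the usual Polyak averaging controlled by TAU. A generic sketch of such a soft update, assuming source and target networks with matching parameters; the helper name is illustrative, not the repository's common_utils API:

def soft_update_sketch(source, target, tau=5e-3):
    """Blend target parameters toward the source: target <- tau*source + (1 - tau)*target."""
    for s_param, t_param in zip(source.parameters(), target.parameters()):
        t_param.data.copy_(tau * s_param.data + (1.0 - tau) * t_param.data)
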
5 changes: 5 additions & 0 deletions algorithms/per/ddpg_agent.py
@@ -10,6 +10,7 @@
 from typing import Tuple
 
 import torch
+import torch.nn as nn
 
 from algorithms.common.buffer.priortized_replay_buffer import PrioritizedReplayBuffer
 import algorithms.common.helper_functions as common_utils
@@ -53,19 +54,23 @@ def update_model(self) -> Tuple[torch.Tensor, ...]:
         curr_returns = curr_returns.to(device).detach()
 
         # train critic
+        gradient_clip_cr = self.hyper_params["GRADIENT_CLIP_CR"]
         values = self.critic(torch.cat((states, actions), dim=-1))
         critic_loss_element_wise = (values - curr_returns).pow(2)
         critic_loss = torch.mean(critic_loss_element_wise * weights)
         self.critic_optimizer.zero_grad()
         critic_loss.backward()
+        nn.utils.clip_grad_norm_(self.critic.parameters(), gradient_clip_cr)
         self.critic_optimizer.step()
 
         # train actor
+        gradient_clip_ac = self.hyper_params["GRADIENT_CLIP_AC"]
         actions = self.actor(states)
         actor_loss_element_wise = -self.critic(torch.cat((states, actions), dim=-1))
         actor_loss = torch.mean(actor_loss_element_wise * weights)
         self.actor_optimizer.zero_grad()
         actor_loss.backward()
+        nn.utils.clip_grad_norm_(self.actor.parameters(), gradient_clip_ac)
         self.actor_optimizer.step()
 
         # update target networks
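
Note: as in the hunk above, per-sample (element-wise) losses are kept so they can be scaled by the importance-sampling weights returned by the prioritized buffer before reduction. A small sketch of that weighting, with made-up tensors standing in for a sampled batch:

import torch

values = torch.randn(64, 1)        # placeholder critic outputs
curr_returns = torch.randn(64, 1)  # placeholder TD targets
weights = torch.rand(64, 1)        # placeholder importance-sampling weights

critic_loss_element_wise = (values - curr_returns).pow(2)
# weight each sample's loss before averaging so that importance sampling
# corrects the bias introduced by prioritized sampling
critic_loss = torch.mean(critic_loss_element_wise * weights)
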
6 changes: 4 additions & 2 deletions examples/lunarlander_continuous_v2/ddpg.py
@@ -21,15 +21,17 @@
 hyper_params = {
     "GAMMA": 0.99,
     "TAU": 5e-3,
-    "BUFFER_SIZE": int(1e5),
+    "BUFFER_SIZE": int(1e4),
     "BATCH_SIZE": 64,
-    "LR_ACTOR": 4e-5,
+    "LR_ACTOR": 3e-4,
     "LR_CRITIC": 3e-4,
     "OU_NOISE_THETA": 0.0,
     "OU_NOISE_SIGMA": 0.0,
     "WEIGHT_DECAY": 1e-6,
     "INITIAL_RANDOM_ACTION": int(1e4),
     "MULTIPLE_LEARN": 1,  # multiple learning updates
+    "GRADIENT_CLIP_AC": 0.5,
+    "GRADIENT_CLIP_CR": 1.0,
 }
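
Note: OU_NOISE_THETA and OU_NOISE_SIGMA parameterize the Ornstein-Uhlenbeck exploration noise sampled in select_action. Below is a generic sketch of such a process, not the repository's noise class; with theta = sigma = 0.0 as configured here, every sample is zero, so early exploration comes from INITIAL_RANDOM_ACTION alone.

import numpy as np

class OUNoiseSketch:
    """Minimal Ornstein-Uhlenbeck process: dx = theta * (mu - x) + sigma * N(0, 1)."""

    def __init__(self, size, mu=0.0, theta=0.0, sigma=0.0):
        self.mu = mu * np.ones(size)
        self.theta = theta
        self.sigma = sigma
        self.state = np.copy(self.mu)

    def sample(self):
        dx = self.theta * (self.mu - self.state) + self.sigma * np.random.randn(len(self.state))
        self.state = self.state + dx
        return self.state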


10 changes: 6 additions & 4 deletions examples/lunarlander_continuous_v2/ddpgfd.py
@@ -23,21 +23,23 @@
     "GAMMA": 0.99,
     "TAU": 5e-3,
     "BUFFER_SIZE": int(1e5),
-    "BATCH_SIZE": 64,
-    "LR_ACTOR": 4e-5,
+    "BATCH_SIZE": 128,
+    "LR_ACTOR": 3e-4,
     "LR_CRITIC": 3e-4,
     "OU_NOISE_THETA": 0.0,
     "OU_NOISE_SIGMA": 0.0,
-    "PRETRAIN_STEP": int(1e3),
+    "PRETRAIN_STEP": int(5e3),
     "MULTIPLE_LEARN": 1,  # multiple learning updates
     "LAMBDA1": 1.0,  # N-step return weight
-    "LAMBDA2": 1e-5,  # l2 regularization weight
+    "LAMBDA2": 1e-4,  # l2 regularization weight
     "LAMBDA3": 1.0,  # actor loss contribution of prior weight
     "PER_ALPHA": 0.3,
     "PER_BETA": 1.0,
     "PER_EPS": 1e-6,
     "PER_EPS_DEMO": 1.0,
     "INITIAL_RANDOM_ACTION": int(1e4),
+    "GRADIENT_CLIP_AC": 0.5,
+    "GRADIENT_CLIP_CR": 1.0,
 }
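
Note: LAMBDA1 weights an n-step return term alongside the 1-step TD loss (the NStepTransitionBuffer import in algorithms/fd/ddpg_agent.py supplies the n-step transitions). A hedged sketch of how an n-step bootstrapped target can be formed from a short reward sequence; the function is illustrative only.

def n_step_return_sketch(rewards, next_value, done, gamma=0.99):
    """Discounted sum of n rewards plus a bootstrapped tail value."""
    target = 0.0
    for k, reward in enumerate(rewards):  # rewards r_t, ..., r_{t+n-1}
        target += (gamma ** k) * reward
    if not done:
        target += (gamma ** len(rewards)) * next_value
    return target

Under that assumption, the two critic losses would be combined roughly as critic_loss = one_step_loss + LAMBDA1 * n_step_loss.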


8 changes: 5 additions & 3 deletions examples/lunarlander_continuous_v2/per-ddpg.py
@@ -21,18 +21,20 @@
 hyper_params = {
     "GAMMA": 0.99,
     "TAU": 5e-3,
-    "BUFFER_SIZE": int(1e5),
-    "BATCH_SIZE": 64,
+    "BUFFER_SIZE": int(1e6),
+    "BATCH_SIZE": 128,
     "LR_ACTOR": 3e-4,
     "LR_CRITIC": 3e-4,
     "OU_NOISE_THETA": 0.0,
     "OU_NOISE_SIGMA": 0.0,
     "PER_ALPHA": 0.6,
     "PER_BETA": 0.4,
     "PER_EPS": 1e-6,
-    "WEIGHT_DECAY": 1e-6,
+    "WEIGHT_DECAY": 5e-6,
     "INITIAL_RANDOM_ACTION": int(1e4),
     "MULTIPLE_LEARN": 1,  # multiple learning updates
+    "GRADIENT_CLIP_AC": 0.5,
+    "GRADIENT_CLIP_CR": 1.0,
 }
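
Note: PER_ALPHA, PER_BETA, and PER_EPS control how strongly TD errors shape sampling and how the resulting bias is corrected. A small sketch of the standard prioritized-replay formulas (Schaul et al., 2015), independent of the repository's buffer implementation; PER_BETA is typically annealed toward 1.0 during training, which is why it starts below 1.0 here.

import numpy as np

def per_quantities_sketch(td_errors, alpha=0.6, beta=0.4, eps=1e-6):
    """Sampling probabilities and importance-sampling weights from TD errors."""
    priorities = (np.abs(td_errors) + eps) ** alpha  # p_i = (|delta_i| + eps)^alpha
    probs = priorities / priorities.sum()            # P(i) = p_i / sum_k p_k
    weights = (len(td_errors) * probs) ** (-beta)    # w_i = (N * P(i))^-beta
    return probs, weights / weights.max()            # normalize weights for stability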


