From ced47af5374d6363c4f01f8cbb9ccadf29881560 Mon Sep 17 00:00:00 2001 From: zrz-sh <2290321870@qq.com> Date: Wed, 26 Jul 2023 15:18:33 +0800 Subject: [PATCH] add sac --- examples/sac/train_sac_beta.py | 51 ++++ openrl/algorithms/sac.py | 351 +++++++++++++++++++++++++ openrl/configs/config.py | 17 +- openrl/drivers/offpolicy_driver.py | 15 +- openrl/modules/common/__init__.py | 2 + openrl/modules/common/sac_net.py | 105 ++++++++ openrl/modules/networks/sac_network.py | 114 ++++++++ openrl/modules/sac_module.py | 219 +++++++++++++++ openrl/runners/common/__init__.py | 2 + openrl/runners/common/sac_agent.py | 142 ++++++++++ 10 files changed, 1002 insertions(+), 16 deletions(-) create mode 100644 examples/sac/train_sac_beta.py create mode 100644 openrl/algorithms/sac.py create mode 100644 openrl/modules/common/sac_net.py create mode 100644 openrl/modules/networks/sac_network.py create mode 100644 openrl/modules/sac_module.py create mode 100644 openrl/runners/common/sac_agent.py diff --git a/examples/sac/train_sac_beta.py b/examples/sac/train_sac_beta.py new file mode 100644 index 00000000..caa4ce4a --- /dev/null +++ b/examples/sac/train_sac_beta.py @@ -0,0 +1,51 @@ +"""""" +import numpy as np + +from openrl.configs.config import create_config_parser +from openrl.envs.common import make +from openrl.modules.common import SACNet as Net +from openrl.runners.common import SACAgent as Agent + + +def train(): + cfg_parser = create_config_parser() + cfg = cfg_parser.parse_args() + + # create environment, set environment parallelism + env = make("Pendulum-v1", env_num=9) + + # create the neural network + net = Net(env, cfg=cfg) + # initialize the trainer + agent = Agent(net) + # start training, set total number of training steps + agent.train(total_time_steps=20000) + + env.close() + return agent + + +def evaluation(agent): + # begin to test + # Create an environment for testing and set the number of environments to interact with to 9. Set rendering mode to group_human. + render_mode = None + env = make("Pendulum-v1", render_mode=render_mode, env_num=9, asynchronous=True) + # The trained agent sets up the interactive environment it needs. + agent.set_env(env) + # Initialize the environment and get initial observations and environmental information. + obs, info = env.reset() + done = False + step = 0 + while not np.any(done): + # Based on environmental observation input, predict next action. + action = agent.act(obs, sample=False) # sample=False in evaluation + obs, r, done, info = env.step(action) + step += 1 + if step % 50 == 0: + print(f"{step}: reward:{np.mean(r)}") + env.close() + + +if __name__ == "__main__": + agent = train() + evaluation(agent) diff --git a/openrl/algorithms/sac.py b/openrl/algorithms/sac.py new file mode 100644 index 00000000..502f302e --- /dev/null +++ b/openrl/algorithms/sac.py @@ -0,0 +1,351 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright 2023 The OpenRL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
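+
+# Overview of the update implemented in this file (see prepare_critic_loss,
+# prepare_actor_loss and prepare_alpha_loss below), with alpha = exp(log_alpha):
+#
+#   critic target:  y = r + gamma * mask * (min(Q1'(s', a'), Q2'(s', a')) - alpha * log pi(a'|s'))
+#   critic losses:  MSE(Q1(s, a), y) and MSE(Q2(s, a), y)
+#   actor loss:     mean(alpha * log pi(a|s) - min(Q1(s, a), Q2(s, a)))
+#   alpha loss:     mean(-alpha * (log pi(a|s) + target_entropy))
+#
+# where a' is re-sampled from the current policy at s', and both target critics are
+# updated with a soft (Polyak) step of rate tau after every sac_update() call.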
+ +"""""" + +from typing import Union + +import torch +import torch.nn.functional as F + +from openrl.algorithms.base_algorithm import BaseAlgorithm +from openrl.modules.networks.utils.distributed_utils import reduce_tensor +from openrl.modules.utils.util import get_gard_norm +from openrl.utils.util import check + + +class SACAlgorithm(BaseAlgorithm): + def __init__( + self, + cfg, + init_module, + agent_num: int = 1, + device: Union[str, torch.device] = "cpu", + ) -> None: + super().__init__(cfg, init_module, agent_num, device) + + self.gamma = cfg.gamma + self.tau = cfg.tau + self.target_entropy = self.algo_module.target_entropy + + def prepare_critic_loss( + self, + obs_batch, + next_obs_batch, + rnn_states_batch, + actions_batch, + masks_batch, + action_masks_batch, + value_preds_batch, + rewards_batch, + active_masks_batch, + turn_on, + ): + ( + target_q_values, + target_q_values_2, + current_q_values, + current_q_values_2, + next_log_prob, + ) = self.algo_module.get_q_values( + obs_batch, + next_obs_batch, + rnn_states_batch, + rewards_batch, + actions_batch, + masks_batch, + action_masks_batch, + active_masks_batch, + ) + + with torch.no_grad(): + next_q_values = ( + torch.min(target_q_values, target_q_values_2) + - torch.exp(self.algo_module.log_alpha) * next_log_prob + ) + q_target = ( + rewards_batch + self.gamma * torch.tensor(masks_batch) * next_q_values + ) + + critic_loss = F.mse_loss(current_q_values, q_target) + critic_loss_2 = F.mse_loss(current_q_values_2, q_target) + + return critic_loss, critic_loss_2 + + def prepare_actor_loss( + self, + obs_batch, + next_obs_batch, + rnn_states_batch, + actions_batch, + masks_batch, + action_masks_batch, + value_preds_batch, + rewards_batch, + active_masks_batch, + turn_on, + ): + actor_loss, log_prob = self.algo_module.evaluate_actor_loss( + obs_batch, + next_obs_batch, + rnn_states_batch, + rewards_batch, + actions_batch, + masks_batch, + action_masks_batch, + active_masks_batch, + ) + + return actor_loss, log_prob + + def prepare_alpha_loss(self, log_prob): + alpha_loss = -( + torch.exp(self.algo_module.log_alpha) + * (log_prob + self.target_entropy).detach() + ).mean() + + return alpha_loss + + def sac_update(self, sample, turn_on=True): + ( + obs_batch, + _, + next_obs_batch, + _, + rnn_states_batch, + rnn_states_critic_batch, + actions_batch, + value_preds_batch, + rewards_batch, + masks_batch, + active_masks_batch, + old_action_log_probs_batch, + adv_targ, + action_masks_batch, + ) = sample + + value_preds_batch = check(value_preds_batch).to(**self.tpdv) + rewards_batch = check(rewards_batch).to(**self.tpdv) + active_masks_batch = check(active_masks_batch).to(**self.tpdv) + + # update critic network + self.algo_module.optimizers["critic"].zero_grad() + self.algo_module.optimizers["critic_2"].zero_grad() + + if self.use_amp: + with torch.cuda.amp.autocast(): + critic_loss, critic_loss_2 = self.prepare_critic_loss( + obs_batch, + next_obs_batch, + rnn_states_batch, + actions_batch, + masks_batch, + action_masks_batch, + value_preds_batch, + rewards_batch, + active_masks_batch, + turn_on, + ) + critic_loss.backward() + critic_loss_2.backward() + else: + critic_loss, critic_loss_2 = self.prepare_critic_loss( + obs_batch, + next_obs_batch, + rnn_states_batch, + actions_batch, + masks_batch, + action_masks_batch, + value_preds_batch, + rewards_batch, + active_masks_batch, + turn_on, + ) + critic_loss.backward() + critic_loss_2.backward() + + if "transformer" in self.algo_module.models: + raise NotImplementedError + else: + 
critic_para = self.algo_module.models["critic"].parameters() + critic_para_2 = self.algo_module.models["critic_2"].parameters() + critic_grad_norm = get_gard_norm(critic_para) + critic_grad_norm_2 = get_gard_norm(critic_para_2) + + if self.use_amp: + raise NotImplementedError + # self.algo_module.scaler.unscale_(self.algo_module.optimizers["critic"]) + # self.algo_module.scaler.step(self.algo_module.optimizers["critic"]) + # self.algo_module.scaler.update() + else: + self.algo_module.optimizers["critic"].step() + self.algo_module.optimizers["critic_2"].step() + + # update actor network + self.algo_module.optimizers["actor"].zero_grad() + + if self.use_amp: + with torch.cuda.amp.autocast(): + actor_loss, log_prob = self.prepare_actor_loss( + obs_batch, + next_obs_batch, + rnn_states_batch, + actions_batch, + masks_batch, + action_masks_batch, + value_preds_batch, + rewards_batch, + active_masks_batch, + turn_on, + ) + actor_loss.backward() + else: + actor_loss, log_prob = self.prepare_actor_loss( + obs_batch, + next_obs_batch, + rnn_states_batch, + actions_batch, + masks_batch, + action_masks_batch, + value_preds_batch, + rewards_batch, + active_masks_batch, + turn_on, + ) + actor_loss.backward() + + if "transformer" in self.algo_module.models: + raise NotImplementedError + else: + actor_para = self.algo_module.models["actor"].parameters() + actor_grad_norm = get_gard_norm(actor_para) + + if self.use_amp: + self.algo_module.scaler.unscale_(self.algo_module.optimizers["actor"]) + self.algo_module.scaler.step(self.algo_module.optimizers["actor"]) + self.algo_module.scaler.update() + else: + self.algo_module.optimizers["actor"].step() + + # update target network + for param, target_param in zip( + self.algo_module.models["critic"].parameters(), + self.algo_module.models["critic_target"].parameters(), + ): + target_param.data.copy_( + self.tau * param.data + (1 - self.tau) * target_param.data + ) + + for param, target_param in zip( + self.algo_module.models["critic_2"].parameters(), + self.algo_module.models["critic_target_2"].parameters(), + ): + target_param.data.copy_( + self.tau * param.data + (1 - self.tau) * target_param.data + ) + + # update alpha + self.algo_module.optimizers["alpha"].zero_grad() + + if self.use_amp: + raise NotImplementedError + else: + alpha_loss = self.prepare_alpha_loss(log_prob) + alpha_loss.backward() + self.algo_module.optimizers["alpha"].step() + + # for others + if self.world_size > 1: + torch.cuda.synchronize() + + loss_list = [] + loss_list.append(critic_loss) + loss_list.append(actor_loss) + loss_list.append(alpha_loss) + + return loss_list + + def cal_value_loss( + self, + value_normalizer, + values, + value_preds_batch, + return_batch, + active_masks_batch, + ): + # TODO:to be finished + raise NotImplementedError( + "The calc_value_loss function in sac.py has not implemented yet" + ) + + def to_single_np(self, input): + reshape_input = input.reshape(-1, self.agent_num, *input.shape[1:]) + return reshape_input[:, 0, ...] 
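+
+    # train() draws num_mini_batch samples from the replay buffer's feed-forward
+    # generator, runs sac_update() on each batch, and returns the losses averaged
+    # over the number of updates (plus all-reduced copies when world_size > 1).
+    # Recurrent and transformer policies are not supported yet and raise
+    # NotImplementedError.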
+
+    def train(self, buffer, turn_on=True):
+        train_info = {}
+
+        train_info["critic_loss"] = 0
+        train_info["actor_loss"] = 0
+        train_info["alpha_loss"] = 0
+        if self.world_size > 1:
+            train_info["reduced_critic_loss"] = 0
+            train_info["reduced_actor_loss"] = 0
+            train_info["reduced_alpha_loss"] = 0
+
+        # TODO: add RNN and transformer support
+
+        for _ in range(self.num_mini_batch):
+            if "transformer" in self.algo_module.models:
+                raise NotImplementedError
+            elif self._use_recurrent_policy:
+                raise NotImplementedError
+            elif self._use_naive_recurrent:
+                raise NotImplementedError
+            else:
+                data_generator = buffer.feed_forward_generator(
+                    None,
+                    num_mini_batch=self.num_mini_batch,
+                    mini_batch_size=self.mini_batch_size,
+                )
+
+            for sample in data_generator:
+                loss_list = self.sac_update(sample, turn_on)
+                if self.world_size > 1:
+                    train_info["reduced_critic_loss"] += reduce_tensor(
+                        loss_list[0].data, self.world_size
+                    )
+                    train_info["reduced_actor_loss"] += reduce_tensor(
+                        loss_list[1].data, self.world_size
+                    )
+                    train_info["reduced_alpha_loss"] += reduce_tensor(
+                        loss_list[2].data, self.world_size
+                    )
+
+                train_info["critic_loss"] += loss_list[0].item()
+                train_info["actor_loss"] += loss_list[1].item()
+                train_info["alpha_loss"] += loss_list[2].item()
+
+        num_updates = 1 * self.num_mini_batch
+
+        for k in train_info.keys():
+            train_info[k] /= num_updates
+
+        for optimizer in self.algo_module.optimizers.values():
+            if hasattr(optimizer, "sync_lookahead"):
+                optimizer.sync_lookahead()
+
+        return train_info
diff --git a/openrl/configs/config.py b/openrl/configs/config.py
index e89cbc18..a7572fc6 100644
--- a/openrl/configs/config.py
+++ b/openrl/configs/config.py
@@ -983,7 +983,7 @@ def create_config_parser():
             "After how many evaluation network updates target network should be updated"
         ),
     )
-    ## for DDPG
+    # for DDPG
     parser.add_argument(
         "--var",
         type=int,
@@ -991,14 +991,15 @@
         help="Control the exploration variance of the generated actions",
     )
     parser.add_argument(
-        "actor_lr", type=float, default=0.001, help="The learning rate of actor network"
+        "--actor_lr", type=float, default=5e-4, help="The learning rate of the actor network"
+    )
+    # for SAC
+    parser.add_argument(
+        "--alpha_lr",
+        type=float,
+        default=2e-4,
+        help="The learning rate of the entropy temperature alpha",
     )
-    # parser.add_argument(
-    #     "critic_lr",
-    #     type=float,
-    #     default=0.002,
-    #     help="The learning rate of critic network",
-    # )
     # update parameters
     parser.add_argument(
         "--use_soft_update",
diff --git a/openrl/drivers/offpolicy_driver.py b/openrl/drivers/offpolicy_driver.py
index dc94d4ed..98d73d60 100644
--- a/openrl/drivers/offpolicy_driver.py
+++ b/openrl/drivers/offpolicy_driver.py
@@ -90,9 +90,8 @@ def _inner_loop(
 
         if self.episode % self.log_interval == 0:
             # rollout_infos can only be used when env is wrapped with VevMonitor
-            # self.logger.log_info(rollout_infos, step=self.total_num_steps)
-            # self.logger.log_info(train_infos, step=self.total_num_steps)
-            pass
+            self.logger.log_info(rollout_infos, step=self.total_num_steps)
+            self.logger.log_info(train_infos, step=self.total_num_steps)
 
         return True
 
@@ -172,7 +171,7 @@ def actor_rollout(self):
                 next_obs, rewards, dones, infos = self.envs.step(actions, extra_data)
                 # print("rewards: ", rewards)
 
-            elif self.algorithm_name == "DDPG":
+            elif self.algorithm_name in ("DDPG", "SAC"):
                 actions = self.act(step)
                 extra_data = {
                     "step": step,
@@ -188,9 +187,9 @@
             # counter += 1
 
         if any(dones):
-            next_obs = np.array(
-                [infos[i]["final_observation"] 
for i in range(len(infos))] - ) + for i in range(len(infos)): + if all(dones[i]): + next_obs[i] = infos[i]["final_observation"] # print("运行次数为:%d, 回报为:%.3f, 探索方差为:%.4f" % (counter, ep_reward, self.var)) # counter = 0 # ep_reward = 0 @@ -290,7 +289,7 @@ def act( rnn_states, ) - elif self.algorithm_name == "DDPG": + elif self.algorithm_name == "DDPG" or self.algorithm_name == "SAC": actions = self.trainer.algo_module.get_actions( self.buffer.data.get_batch_data( "next_policy_obs" if step != 0 else "policy_obs", step diff --git a/openrl/modules/common/__init__.py b/openrl/modules/common/__init__.py index 46c3e3da..3a41122f 100644 --- a/openrl/modules/common/__init__.py +++ b/openrl/modules/common/__init__.py @@ -5,6 +5,7 @@ from .gail_net import GAILNet from .mat_net import MATNet from .ppo_net import PPONet +from .sac_net import SACNet from .vdn_net import VDNNet __all__ = [ @@ -16,4 +17,5 @@ "VDNNet", "GAILNet", "BCNet", + "SACNet", ] diff --git a/openrl/modules/common/sac_net.py b/openrl/modules/common/sac_net.py new file mode 100644 index 00000000..4d90f02e --- /dev/null +++ b/openrl/modules/common/sac_net.py @@ -0,0 +1,105 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright 2023 The OpenRL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""""" + +from typing import Any, Dict, Optional, Tuple, Union + +import gymnasium as gym +import numpy as np +import torch + +from openrl.configs.config import create_config_parser +from openrl.modules.base_module import BaseModule +from openrl.modules.common.base_net import BaseNet +from openrl.modules.sac_module import SACModule +from openrl.utils.util import set_seed + + +class SACNet(BaseNet): + def __init__( + self, + env: Union[gym.Env, str], + cfg=None, + device: Union[torch.device, str] = "cpu", + n_rollout_threads: int = 1, + model_dict: Optional[Dict[str, Any]] = None, + module_class: BaseModule = SACModule, + ) -> None: + super().__init__() + + if cfg is None: + cfg_parser = create_config_parser() + cfg = cfg_parser.parse_args() + + set_seed(cfg.seed) + env.reset(seed=cfg.seed) + + cfg.num_agents = env.agent_num + cfg.n_rollout_threads = n_rollout_threads + cfg.learner_n_rollout_threads = cfg.n_rollout_threads + cfg.algorithm_name = "SAC" + + if cfg.rnn_type == "gru": + rnn_hidden_size = cfg.hidden_size + elif cfg.rnn_type == "lstm": + rnn_hidden_size = cfg.hidden_size * 2 + else: + raise NotImplementedError( + f"RNN type {cfg.rnn_type} has not been implemented." 
+ ) + cfg.rnn_hidden_size = rnn_hidden_size + + if isinstance(device, str): + device = torch.device(device) + + self.module = module_class( + cfg=cfg, + input_space=env.observation_space, + act_space=env.action_space, + device=device, + rank=0, + world_size=1, + model_dict=model_dict, + ) + + self.cfg = cfg + self.env = env + self.device = device + self.rnn_states_actor = None + self.masks = None + + def act( + self, + observation: Union[np.ndarray, Dict[str, np.ndarray]], + sample=True, + ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]: + actions = self.module.act(observation, sample=sample).detach().numpy() + return actions + + def reset(self, env: Optional[gym.Env] = None) -> None: + if env is not None: + self.env = env + self.first_reset = False + self.rnn_states_actor, self.masks = self.module.init_rnn_states( + rollout_num=self.env.parallel_env_num, + agent_num=self.env.agent_num, + rnn_layers=self.cfg.recurrent_N, + hidden_size=self.cfg.rnn_hidden_size, + ) + + def load_policy(self, path: str) -> None: + self.module.load_policy(path) diff --git a/openrl/modules/networks/sac_network.py b/openrl/modules/networks/sac_network.py new file mode 100644 index 00000000..2b23b8e8 --- /dev/null +++ b/openrl/modules/networks/sac_network.py @@ -0,0 +1,114 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright 2021 The OpenRL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+""""""
+
+import gymnasium as gym
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from openrl.buffers.utils.util import get_critic_obs_space, get_policy_obs_space
+from openrl.modules.networks.base_policy_network import BasePolicyNetwork
+from openrl.modules.networks.base_value_network import BaseValueNetwork
+from openrl.modules.networks.ddpg_network import ActorNetwork
+from openrl.modules.networks.utils.cnn import CNNBase
+from openrl.modules.networks.utils.mix import MIXBase
+from openrl.modules.networks.utils.mlp import MLPBase
+from openrl.modules.networks.utils.rnn import RNNLayer
+from openrl.modules.networks.utils.util import init
+from openrl.utils.util import check_v2 as check
+
+
+class SACActorNetwork(ActorNetwork):
+    def __init__(
+        self,
+        cfg,
+        input_space,
+        action_space,
+        device=torch.device("cpu"),
+        use_half=False,
+        extra_args=None,
+        log_std_min=-20,
+        log_std_max=2,
+    ) -> None:
+        super().__init__(cfg, input_space, action_space, device, use_half, extra_args)
+
+        init_method = [nn.init.xavier_uniform_, nn.init.orthogonal_][
+            self._use_orthogonal
+        ]
+        input_size = self.base.output_size
+
+        def init_(m):
+            return init(m, init_method, lambda x: nn.init.constant_(x, 0))
+
+        if isinstance(self.action_space, gym.spaces.box.Box):
+            # single linear head producing mean and log_std stacked on the last dimension
+            self.actor_out = init_(nn.Linear(input_size, action_space.shape[0] * 2))
+        else:
+            raise NotImplementedError("This type of action space has not been implemented.")
+
+        self.log_std_min = log_std_min
+        self.log_std_max = log_std_max
+
+    def forward(self, obs):
+        if self._mixed_obs:
+            for key in obs.keys():
+                obs[key] = check(obs[key]).to(**self.tpdv)
+        else:
+            obs = check(obs).to(**self.tpdv)
+
+        features = self.base(obs)
+
+        if isinstance(self.action_space, gym.spaces.box.Box):
+            output = self.actor_out(features)
+            mean, log_std = output.chunk(2, dim=-1)
+            log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max)
+        else:
+            raise NotImplementedError("This type of action space has not been implemented.")
+
+        return mean, log_std
+
+    def _normalize(self, action) -> torch.Tensor:
+        """
+        Rescale a squashed action to the action space range.
+        The tanh-squashed policy output lies in (-1, 1), while the environment expects actions in (self.action_space.low, self.action_space.high); concretely, env_action = low + (action + 1) / 2 * (high - low).
+ """ + action = (action + 1) / 2 * ( + torch.tensor(self.action_space.high) - torch.tensor(self.action_space.low) + ) + torch.tensor(self.action_space.low) + return action + + def evaluate(self, obs, sample=True): + mean, log_std = self.forward(obs) + if not sample: + action = torch.tanh(mean) # add tanh to activate + return self._normalize(action), None + + # sample action from N(mean, std) if sample is True + # obtain log_prob for policy and Q function update + # use the reparameterization trick, and perform tanh normalization + + std = torch.exp(log_std) + dist = torch.distributions.Normal(mean, std) + action = dist.rsample() + log_prob = dist.log_prob(action).sum(axis=-1) + log_prob -= (2 * (np.log(2) - action - F.softplus(-2 * action))).sum( + axis=-1 + ) # NOTE: The correction formula from the original SAC paper (arXiv 1801.01290) appendix C + action = torch.tanh(action) # add tanh to activate + + return self._normalize(action), log_prob.unsqueeze(dim=-1) diff --git a/openrl/modules/sac_module.py b/openrl/modules/sac_module.py new file mode 100644 index 00000000..5700f2bc --- /dev/null +++ b/openrl/modules/sac_module.py @@ -0,0 +1,219 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright 2023 The OpenRL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""""" + +from typing import Any, Dict, Optional, Union + +import gym +import numpy as np +import torch + +from openrl.modules.model_config import ModelTrainConfig +from openrl.modules.networks.ddpg_network import CriticNetwork +from openrl.modules.networks.sac_network import SACActorNetwork +from openrl.modules.rl_module import RLModule +from openrl.modules.utils.util import update_linear_schedule + + +class SACModule(RLModule): + def __init__( + self, + cfg, + input_space: gym.spaces.Box, + act_space: gym.spaces.Box, + device: Union[str, torch.device] = "cpu", + rank: Optional[int] = None, + world_size: Optional[int] = None, + model_dict: Optional[Dict[str, Any]] = None, + ): + model_configs = {} + model_configs["actor"] = ModelTrainConfig( + lr=cfg.actor_lr, + model=( + model_dict["actor"] + if model_dict and "actor" in model_dict + else SACActorNetwork + ), + input_space=input_space, + ) + model_configs["critic"] = ModelTrainConfig( + lr=cfg.critic_lr, + model=( + model_dict["critic"] + if model_dict and "critic" in model_dict + else CriticNetwork + ), + input_space=input_space, + ) + model_configs["critic_target"] = ModelTrainConfig( + lr=cfg.critic_lr, + model=( + model_dict["critic_target"] + if model_dict and "critic_target" in model_dict + else CriticNetwork + ), + input_space=input_space, + ) + model_configs["critic_2"] = ModelTrainConfig( + lr=cfg.critic_lr, + model=( + model_dict["critic_2"] + if model_dict and "critic_2" in model_dict + else CriticNetwork + ), + input_space=input_space, + ) + model_configs["critic_target_2"] = ModelTrainConfig( + lr=cfg.critic_lr, + model=( + model_dict["critic_target_2"] + if model_dict and "critic_target_2" in model_dict + else CriticNetwork + ), + input_space=input_space, + ) + + super().__init__( + cfg=cfg, + model_configs=model_configs, + act_space=act_space, + rank=rank, + world_size=world_size, + device=device, + ) + self.obs_space = input_space + self.act_space = act_space + self.cfg = cfg + + # alpha (can be dynamically adjusted) + self.log_alpha = torch.zeros(1, requires_grad=True, device=device) + alpha_optimizer = torch.optim.Adam( + [self.log_alpha], + lr=cfg.alpha_lr, + eps=cfg.opti_eps, + weight_decay=cfg.weight_decay, + ) + self.optimizers["alpha"] = alpha_optimizer + self.target_entropy = -np.prod(act_space.shape).item() + + def lr_decay(self, episode, episodes): + update_linear_schedule( + self.optimizers["critic"], episode, episodes, self.cfg.critic_lr + ) + update_linear_schedule( + self.optimizers["critic_2"], episode, episodes, self.cfg.critic_lr + ) + update_linear_schedule( + self.optimizers["actor"], episode, episodes, self.cfg.actor_lr + ) + update_linear_schedule( + self.optimizers["alpha"], episode, episodes, self.cfg.alpha_lr + ) + + def get_actions(self, obs, sample=True): + actions, _ = self.models["actor"].evaluate(obs, sample=sample) + + return actions + + def get_values(self, obs, action, rnn_states_critic, masks): + critic_values, _ = self.models["critic"](obs, action, rnn_states_critic, masks) + + return critic_values + + def evaluate_actor_loss( + self, + obs_batch, + next_obs_batch, + rnn_states_batch, + rewards_batch, + actions_batch, + masks, + action_masks=None, + masks_batch=None, + ): + if masks_batch is None: + masks_batch = masks + + action, log_prob = self.models["actor"].evaluate(obs_batch) + q_values = torch.min( + self.models["critic"](obs_batch, action, rnn_states_batch, masks_batch)[0], + self.models["critic_2"](obs_batch, action, rnn_states_batch, masks_batch)[ + 0 + ], + ) + actor_loss = 
(torch.exp(self.log_alpha) * log_prob - q_values).mean() + + return actor_loss, log_prob + + def get_q_values( + self, + obs_batch, + next_obs_batch, + rnn_states_batch, + rewards_batch, + actions_batch, + masks, + action_masks=None, + masks_batch=None, + ): + if masks_batch is None: + masks_batch = masks + + current_q_values, _ = self.models["critic"]( + obs_batch, actions_batch, rnn_states_batch, masks_batch + ) + + current_q_values_2, _ = self.models["critic_2"]( + obs_batch, actions_batch, rnn_states_batch, masks_batch + ) + + with torch.no_grad(): + next_action, next_log_prob = self.models["actor"].evaluate(next_obs_batch) + target_q_values, _ = self.models["critic_target"]( + next_obs_batch, next_action, rnn_states_batch, masks_batch + ) + target_q_values_2, _ = self.models["critic_target_2"]( + next_obs_batch, next_action, rnn_states_batch, masks_batch + ) + + return ( + target_q_values, + target_q_values_2, + current_q_values, + current_q_values_2, + next_log_prob, + ) + + def evaluate_actions(self): + # This function is not required in SAC + pass + + def act(self, obs, sample=True): + actions, _ = self.models["actor"].evaluate(obs, sample=sample) + + return actions + + def get_critic_value_normalizer(self): + return self.models["critic"].value_normalizer + + @staticmethod + def init_rnn_states( + rollout_num: int, agent_num: int, rnn_layers: int, hidden_size: int + ): + masks = np.ones((rollout_num * agent_num, 1), dtype=np.float32) + rnn_state = np.zeros((rollout_num * agent_num, rnn_layers, hidden_size)) + return rnn_state, masks diff --git a/openrl/runners/common/__init__.py b/openrl/runners/common/__init__.py index 1ff5381e..d315ecf3 100644 --- a/openrl/runners/common/__init__.py +++ b/openrl/runners/common/__init__.py @@ -5,6 +5,7 @@ from openrl.runners.common.gail_agent import GAILAgent from openrl.runners.common.mat_agent import MATAgent from openrl.runners.common.ppo_agent import PPOAgent +from openrl.runners.common.sac_agent import SACAgent from openrl.runners.common.vdn_agent import VDNAgent __all__ = [ @@ -17,4 +18,5 @@ "VDNAgent", "GAILAgent", "BCAgent", + "SACAgent", ] diff --git a/openrl/runners/common/sac_agent.py b/openrl/runners/common/sac_agent.py new file mode 100644 index 00000000..b81fdd56 --- /dev/null +++ b/openrl/runners/common/sac_agent.py @@ -0,0 +1,142 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright 2023 The OpenRL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
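+
+# SACAgent mirrors the other off-policy agents: train() builds a SACAlgorithm
+# trainer, an OffPolicyReplayBuffer and an OffPolicyDriver and then hands control
+# to the driver, while act() prepares observations with ObsData.prepare_input and
+# splits the flat action batch back into one entry per parallel environment.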
+ +"""""" +from typing import Dict, Optional, Tuple, Type, Union + +import gym +import numpy as np +import torch + +from openrl.algorithms.base_algorithm import BaseAlgorithm +from openrl.algorithms.sac import SACAlgorithm as TrainAlgo +from openrl.buffers import OffPolicyReplayBuffer as ReplayBuffer +from openrl.buffers.utils.obs_data import ObsData +from openrl.drivers.base_driver import BaseDriver +from openrl.drivers.offpolicy_driver import OffPolicyDriver as Driver +from openrl.runners.common.base_agent import SelfAgent +from openrl.runners.common.rl_agent import RLAgent +from openrl.utils.logger import Logger +from openrl.utils.type_aliases import MaybeCallback + + +class SACAgent(RLAgent): + def __init__( + self, + net: Optional[torch.nn.Module] = None, + env: Union[gym.Env, str] = None, + run_dir: Optional[str] = None, + env_num: Optional[int] = None, + rank: int = 0, + world_size: int = 1, + use_wandb: bool = False, + use_tensorboard: bool = False, + project_name: str = "SACAgent", + ) -> None: + super(SACAgent, self).__init__( + net, + env, + run_dir, + env_num, + rank, + world_size, + use_wandb, + use_tensorboard, + project_name=project_name, + ) + + def train( + self: SelfAgent, + total_time_steps: int, + callback: MaybeCallback = None, + train_algo_class: Type[BaseAlgorithm] = TrainAlgo, + logger: Optional[Logger] = None, + DriverClass: Type[BaseDriver] = Driver, + ) -> None: + self._cfg.num_env_steps = total_time_steps + + self.config = { + "cfg": self._cfg, + "num_agents": self.agent_num, + "run_dir": self.run_dir, + "envs": self._env, + "device": self.net.device, + } + + trainer = train_algo_class( + cfg=self._cfg, + init_module=self.net.module, + device=self.net.device, + agent_num=self.agent_num, + ) + + buffer = ReplayBuffer( + self._cfg, + self.agent_num, + self._env.observation_space, + self._env.action_space, + data_client=None, + episode_length=self._cfg.episode_length, + ) + + if logger is None: + logger = Logger( + cfg=self._cfg, + project_name=self.project_name, + scenario_name=self._env.env_name, + wandb_entity=self._cfg.wandb_entity, + exp_name=self.exp_name, + log_path=self.run_dir, + use_wandb=self._use_wandb, + use_tensorboard=self._use_tensorboard, + ) + self._logger = logger + + total_time_steps, callback = self._setup_train( + total_time_steps, + callback, + reset_num_time_steps=True, + progress_bar=False, + ) + + driver = DriverClass( + config=self.config, + trainer=trainer, + buffer=buffer, + agent=self, + client=self.client, + rank=self.rank, + world_size=self.world_size, + logger=logger, + callback=callback, + ) + + callback.on_training_start(locals(), globals()) + + driver.run() + + callback.on_training_end() + + def act( + self, observation: Union[np.ndarray, Dict[str, np.ndarray]], sample=True + ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]: + assert self.net is not None, "net is None" + observation = ObsData.prepare_input(observation) + + action = self.net.act(observation, sample=sample) + action = np.array(np.split(action, self.env_num)) + + return action