diff --git a/openrl/algorithms/dqn.py b/openrl/algorithms/dqn.py new file mode 100644 index 00000000..9f0a1be6 --- /dev/null +++ b/openrl/algorithms/dqn.py @@ -0,0 +1,223 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright 2023 The OpenRL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""""" + +from typing import Union + +import numpy as np +import torch +import torch.nn as nn +from torch.nn.parallel import DistributedDataParallel + +from openrl.algorithms.base_algorithm import BaseAlgorithm +from openrl.modules.networks.utils.distributed_utils import reduce_tensor +from openrl.modules.utils.util import get_gard_norm, huber_loss, mse_loss +from openrl.utils.util import check + + +class DQNAlgorithm(BaseAlgorithm): + def __init__( + self, + cfg, + init_module, + agent_num: int = 1, + device: Union[str, torch.device] = "cpu", + ) -> None: + self._use_share_model = cfg.use_share_model + self.use_joint_action_loss = cfg.use_joint_action_loss + super(DQNAlgorithm, self).__init__(cfg, init_module, agent_num, device) + + def dqn_update(self, sample, turn_on=True): + for optimizer in self.algo_module.optimizers.values(): + optimizer.zero_grad() + + ( + obs_batch, + rnn_states_batch, + actions_batch, + value_preds_batch, + return_batch, + masks_batch, + active_masks_batch, + available_actions_batch, + ) = sample + + value_preds_batch = check(value_preds_batch).to(**self.tpdv) + return_batch = check(return_batch).to(**self.tpdv) + active_masks_batch = check(active_masks_batch).to(**self.tpdv) + + if self.use_amp: + with torch.cuda.amp.autocast(): + ( + loss_list, + value_loss, + policy_loss, + dist_entropy, + ratio, + ) = self.prepare_loss( + obs_batch, + rnn_states_batch, + actions_batch, + masks_batch, + available_actions_batch, + value_preds_batch, + return_batch, + active_masks_batch, + turn_on, + ) + for loss in loss_list: + self.algo_module.scaler.scale(loss).backward() + else: + loss_list, value_loss, policy_loss, dist_entropy, ratio = self.prepare_loss( + obs_batch, + rnn_states_batch, + actions_batch, + masks_batch, + available_actions_batch, + value_preds_batch, + return_batch, + active_masks_batch, + turn_on, + ) + for loss in loss_list: + loss.backward() + + if "transformer" in self.algo_module.models: + if self._use_max_grad_norm: + grad_norm = nn.utils.clip_grad_norm_( + self.algo_module.models["transformer"].parameters(), + self.max_grad_norm, + ) + else: + grad_norm = get_gard_norm( + self.algo_module.models["transformer"].parameters() + ) + critic_grad_norm = grad_norm + actor_grad_norm = grad_norm + + else: + if self._use_share_model: + actor_para = self.algo_module.models["model"].get_actor_para() + else: + actor_para = self.algo_module.models["policy"].parameters() + + if self._use_max_grad_norm: + actor_grad_norm = nn.utils.clip_grad_norm_( + actor_para, self.max_grad_norm + ) + else: + actor_grad_norm = get_gard_norm(actor_para) + + if self._use_share_model: + critic_para = self.algo_module.models["model"].get_critic_para() + else: + critic_para = 
self.algo_module.models["critic"].parameters() + + if self._use_max_grad_norm: + critic_grad_norm = nn.utils.clip_grad_norm_( + critic_para, self.max_grad_norm + ) + else: + critic_grad_norm = get_gard_norm(critic_para) + + if self.use_amp: + for optimizer in self.algo_module.optimizers.values(): + self.algo_module.scaler.unscale_(optimizer) + + for optimizer in self.algo_module.optimizers.values(): + self.algo_module.scaler.step(optimizer) + + self.algo_module.scaler.update() + else: + for optimizer in self.algo_module.optimizers.values(): + optimizer.step() + + if self.world_size > 1: + torch.cuda.synchronize() + + return ( + value_loss, + critic_grad_norm, + policy_loss, + dist_entropy, + actor_grad_norm, + ratio, + ) + + def cal_value_loss( + self, + value_normalizer, + values, + value_preds_batch, + return_batch, + active_masks_batch, + ): + value_pred_clipped = value_preds_batch + (values - value_preds_batch).clamp( + -self.clip_param, self.clip_param + ) + + if self._use_popart or self._use_valuenorm: + value_normalizer.update(return_batch) + error_clipped = ( + value_normalizer.normalize(return_batch) - value_pred_clipped + ) + error_original = value_normalizer.normalize(return_batch) - values + else: + error_clipped = return_batch - value_pred_clipped + error_original = return_batch - values + + if self._use_huber_loss: + value_loss_clipped = huber_loss(error_clipped, self.huber_delta) + value_loss_original = huber_loss(error_original, self.huber_delta) + else: + value_loss_clipped = mse_loss(error_clipped) + value_loss_original = mse_loss(error_original) + + if self._use_clipped_value_loss: + value_loss = torch.max(value_loss_original, value_loss_clipped) + else: + value_loss = value_loss_original + + if self._use_value_active_masks: + value_loss = ( + value_loss * active_masks_batch + ).sum() / active_masks_batch.sum() + else: + value_loss = value_loss.mean() + + return value_loss + + def to_single_np(self, input): + reshape_input = input.reshape(-1, self.agent_num, *input.shape[1:]) + return reshape_input[:, 0, ...] + + def prepare_loss( + self, + obs_batch, + rnn_states_batch, + actions_batch, + masks_batch, + available_actions_batch, + value_preds_batch, + return_batch, + active_masks_batch, + turn_on, + ): + raise NotImplementedError + + def train(self, buffer, turn_on=True): + raise NotImplementedError diff --git a/openrl/buffers/normal_buffer.py b/openrl/buffers/normal_buffer.py index 5716c27a..f431573f 100644 --- a/openrl/buffers/normal_buffer.py +++ b/openrl/buffers/normal_buffer.py @@ -106,3 +106,6 @@ def recurrent_generator(self, advantages, num_mini_batch, data_chunk_length): return self.data.recurrent_generator( advantages, num_mini_batch, data_chunk_length ) + + def get_buffer_size(self): + return self.data.critic_obs.shape[0] \ No newline at end of file diff --git a/openrl/drivers/offpolicy_driver.py b/openrl/drivers/offpolicy_driver.py new file mode 100644 index 00000000..6f4ca603 --- /dev/null +++ b/openrl/drivers/offpolicy_driver.py @@ -0,0 +1,210 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright 2023 The OpenRL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""""" +from typing import Any, Dict, Optional + +import numpy as np +import torch +from torch.nn.parallel import DistributedDataParallel + +from openrl.drivers.rl_driver import RLDriver +from openrl.utils.logger import Logger +from openrl.utils.util import _t2n + + +class OffPolicyDriver(RLDriver): + def __init__( + self, + config: Dict[str, Any], + trainer, + buffer, + rank: int = 0, + world_size: int = 1, + client=None, + logger: Optional[Logger] = None, + ) -> None: + super(OffPolicyDriver, self).__init__(config, trainer, buffer, rank, world_size, client, logger) + + self.buffer_minimal_size = config["cfg"].buffer_size * 0.1 + + def _inner_loop( + self, + ) -> None: + rollout_infos = self.actor_rollout() + + if self.buffer.get_buffer_size() > self.buffer_minimal_size: + train_infos = self.learner_update() + self.buffer.after_update() + else: + train_infos = {'value_loss': 0, + 'policy_loss': 0, + 'dist_entropy': 0, + 'actor_grad_norm': 0, + 'critic_grad_norm': 0, + 'ratio': 0} + + self.total_num_steps = ( + (self.episode + 1) * self.episode_length * self.n_rollout_threads + ) + + if self.episode % self.log_interval == 0: + # rollout_infos can only be used when env is wrapped with VevMonitor + self.logger.log_info(rollout_infos, step=self.total_num_steps) + self.logger.log_info(train_infos, step=self.total_num_steps) + + def add2buffer(self, data): + ( + obs, + rewards, + dones, + infos, + values, + actions, + action_log_probs, + rnn_states, + rnn_states_critic, + ) = data + + rnn_states[dones] = np.zeros( + (dones.sum(), self.recurrent_N, self.hidden_size), + dtype=np.float32, + ) + + rnn_states_critic[dones] = np.zeros( + (dones.sum(), *self.buffer.data.rnn_states_critic.shape[3:]), + dtype=np.float32, + ) + masks = np.ones((self.n_rollout_threads, self.num_agents, 1), dtype=np.float32) + masks[dones] = np.zeros((dones.sum(), 1), dtype=np.float32) + + self.buffer.insert( + obs, + rnn_states, + rnn_states_critic, + actions, + action_log_probs, + values, + rewards, + masks, + ) + + def actor_rollout(self): + self.trainer.prep_rollout() + import time + + for step in range(self.episode_length): + values, actions, action_log_probs, rnn_states, rnn_states_critic = self.act( + step + ) + + extra_data = { + "values": values, + "action_log_probs": action_log_probs, + "step": step, + "buffer": self.buffer, + } + + obs, rewards, dones, infos = self.envs.step(actions, extra_data) + + data = ( + obs, + rewards, + dones, + infos, + values, + actions, + action_log_probs, + rnn_states, + rnn_states_critic, + ) + + self.add2buffer(data) + + batch_rew_infos = self.envs.batch_rewards(self.buffer) + + if self.envs.use_monitor: + statistics_info = self.envs.statistics(self.buffer) + statistics_info.update(batch_rew_infos) + return statistics_info + else: + return batch_rew_infos + + @torch.no_grad() + def compute_returns(self): + self.trainer.prep_rollout() + + next_values = self.trainer.algo_module.get_values( + self.buffer.data.get_batch_data("critic_obs", -1), + np.concatenate(self.buffer.data.rnn_states_critic[-1]), + np.concatenate(self.buffer.data.masks[-1]), + ) + + next_values = 
np.array( + np.split(_t2n(next_values), self.learner_n_rollout_threads) + ) + if "critic" in self.trainer.algo_module.models and isinstance( + self.trainer.algo_module.models["critic"], DistributedDataParallel + ): + value_normalizer = self.trainer.algo_module.models[ + "critic" + ].module.value_normalizer + elif "model" in self.trainer.algo_module.models and isinstance( + self.trainer.algo_module.models["model"], DistributedDataParallel + ): + value_normalizer = self.trainer.algo_module.models["model"].value_normalizer + else: + value_normalizer = self.trainer.algo_module.get_critic_value_normalizer() + self.buffer.compute_returns(next_values, value_normalizer) + + @torch.no_grad() + def act( + self, + step: int, + ): + self.trainer.prep_rollout() + + ( + value, + action, + action_log_prob, + rnn_states, + rnn_states_critic, + ) = self.trainer.algo_module.get_actions( + self.buffer.data.get_batch_data("critic_obs", step), + self.buffer.data.get_batch_data("policy_obs", step), + np.concatenate(self.buffer.data.rnn_states[step]), + np.concatenate(self.buffer.data.rnn_states_critic[step]), + np.concatenate(self.buffer.data.masks[step]), + ) + + values = np.array(np.split(_t2n(value), self.n_rollout_threads)) + actions = np.array(np.split(_t2n(action), self.n_rollout_threads)) + action_log_probs = np.array( + np.split(_t2n(action_log_prob), self.n_rollout_threads) + ) + rnn_states = np.array(np.split(_t2n(rnn_states), self.n_rollout_threads)) + rnn_states_critic = np.array( + np.split(_t2n(rnn_states_critic), self.n_rollout_threads) + ) + + return ( + values, + actions, + action_log_probs, + rnn_states, + rnn_states_critic, + ) \ No newline at end of file diff --git a/openrl/modules/common/__init__.py b/openrl/modules/common/__init__.py index bc05184d..974f2e12 100644 --- a/openrl/modules/common/__init__.py +++ b/openrl/modules/common/__init__.py @@ -1,5 +1,7 @@ from .ppo_net import PPONet +from .dqn_net import DQNNet __all__ = [ "PPONet", + "DQNNet", ] diff --git a/openrl/modules/common/dqn_net.py b/openrl/modules/common/dqn_net.py new file mode 100644 index 00000000..77661d77 --- /dev/null +++ b/openrl/modules/common/dqn_net.py @@ -0,0 +1,106 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright 2023 The OpenRL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""""" + +from typing import Any, Dict, Optional, Tuple, Union + +import gymnasium as gym +import numpy as np +import torch + +from openrl.configs.config import create_config_parser +from openrl.modules.common.base_net import BaseNet +from openrl.modules.dqn_module import DQNModule +from openrl.utils.util import set_seed + + +class DQNNet(BaseNet): + def __init__( + self, + env: Union[gym.Env, str], + cfg=None, + device: Union[torch.device, str] = "cpu", + n_rollout_threads: int = 1, + model_dict: Optional[Dict[str, Any]] = None, + ) -> None: + super().__init__() + + if cfg is None: + cfg_parser = create_config_parser() + cfg = cfg_parser.parse_args() + + set_seed(cfg.seed) + env.reset(seed=cfg.seed) + + cfg.n_rollout_threads = n_rollout_threads + cfg.learner_n_rollout_threads = cfg.n_rollout_threads + + if cfg.rnn_type == "gru": + rnn_hidden_size = cfg.hidden_size + elif cfg.rnn_type == "lstm": + rnn_hidden_size = cfg.hidden_size * 2 + else: + raise NotImplementedError( + f"RNN type {cfg.rnn_type} has not been implemented." + ) + cfg.rnn_hidden_size = rnn_hidden_size + + if isinstance(device, str): + device = torch.device(device) + + self.module = DQNModule( + cfg=cfg, + input_space=env.observation_space, + act_space=env.action_space, + device=device, + rank=0, + world_size=1, + model_dict=model_dict, + ) + + self.cfg = cfg + self.env = env + self.device = device + self.rnn_states_actor = None + self.masks = None + + def act( + self, + observation: Union[np.ndarray, Dict[str, np.ndarray]] + ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]: + actions, self.rnn_states_actor = self.module.act( + obs=observation, + rnn_states_actor=self.rnn_states_actor, + masks=self.masks, + available_actions=None, + ) + + return actions, self.rnn_states_actor + + def reset(self, env: Optional[gym.Env] = None) -> None: + if env is not None: + self.env = env + self.first_reset = False + self.rnn_states_actor, self.masks = self.module.init_rnn_states( + rollout_num=self.env.parallel_env_num, + agent_num=self.env.agent_num, + rnn_layers=self.cfg.recurrent_N, + hidden_size=self.cfg.rnn_hidden_size, + ) + + def load_policy(self, path: str) -> None: + self.module.load_policy(path) diff --git a/openrl/modules/dqn_module.py b/openrl/modules/dqn_module.py new file mode 100644 index 00000000..e640ebbe --- /dev/null +++ b/openrl/modules/dqn_module.py @@ -0,0 +1,141 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright 2023 The OpenRL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""""" + +from typing import Any, Dict, Optional, Union + +import gym +import numpy as np +import torch + +from openrl.modules.model_config import ModelTrainConfig +from openrl.modules.networks.value_network import ValueNetwork +from openrl.modules.rl_module import RLModule +from openrl.modules.utils.util import update_linear_schedule + + +class DQNModule(RLModule): + def __init__( + self, + cfg, + input_space: gym.spaces.Box, + act_space: gym.spaces.Box, + device: Union[str, torch.device] = "cpu", + rank: Optional[int] = None, + world_size: Optional[int] = None, + model_dict: Optional[Dict[str, Any]] = None, + ): + model_configs = {} + + + model_configs["q_net"] = ModelTrainConfig( + lr=cfg.lr, + model=( + model_dict["q_net"] + if model_dict and "q_net" in model_dict + else ValueNetwork + ), + input_space=input_space, + ) + model_configs["target_q_net"] = ModelTrainConfig( + lr=cfg.lr, + model=( + model_dict["target_q_net"] + if model_dict and "target_q_net" in model_dict + else ValueNetwork + ), + input_space=input_space, + ) + + super(DQNModule, self).__init__( + cfg=cfg, + model_configs=model_configs, + act_space=act_space, + rank=rank, + world_size=world_size, + device=device, + ) + self.cfg = cfg + + def lr_decay(self, episode, episodes): + + update_linear_schedule(self.optimizers["q_net"], episode, episodes, self.lr) + + def get_actions( + self, + obs, + rnn_states_actor, + masks, + available_actions=None, + ): + + values, actions, rnn_states_actor = self.models["q_net"]( + "original", + obs, + rnn_states_actor, + masks, + available_actions, + ) + + return values, actions, rnn_states_actor + + def get_values(self, critic_obs, rnn_states_critic, masks): + values, _ = self.models["q_net"](critic_obs, rnn_states_critic, masks) + return values + + def evaluate_actions( + self, + obs, + rnn_states, + masks, + available_actions=None, + masks_batch=None, + ): + if masks_batch is None: + masks_batch = masks + + values, _ = self.models["q_net"]( + obs, rnn_states, masks_batch, available_actions + ) + + return values + + def act( + self, obs, rnn_states_actor, masks, available_actions=None + ): + + model = self.models["q_net"] + + actions, _, rnn_states_actor = model( + obs, + rnn_states_actor, + masks, + available_actions, + ) + + return actions, rnn_states_actor + + def get_critic_value_normalizer(self): + return self.models["q_net"].value_normalizer + + @staticmethod + def init_rnn_states( + rollout_num: int, agent_num: int, rnn_layers: int, hidden_size: int + ): + masks = np.ones((rollout_num * agent_num, 1), dtype=np.float32) + rnn_state = np.zeros((rollout_num * agent_num, rnn_layers, hidden_size)) + return rnn_state, masks diff --git a/openrl/modules/networks/q_network.py b/openrl/modules/networks/q_network.py new file mode 100644 index 00000000..f8ca8995 --- /dev/null +++ b/openrl/modules/networks/q_network.py @@ -0,0 +1,117 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright 2021 The OpenRL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""""" +import torch + +from openrl.buffers.utils.util import get_policy_obs, get_policy_obs_space +from openrl.modules.networks.base_policy_network import BasePolicyNetwork +from openrl.modules.networks.utils.act import ACTLayer +from openrl.modules.networks.utils.cnn import CNNBase +from openrl.modules.networks.utils.mix import MIXBase +from openrl.modules.networks.utils.mlp import MLPBase +from openrl.modules.networks.utils.rnn import RNNLayer +from openrl.utils.util import check_v2 as check + + +class QNetwork(BasePolicyNetwork): + def __init__( + self, + cfg, + input_space, + action_space, + device=torch.device("cpu"), + use_half=False, + ) -> None: + super(QNetwork, self).__init__(cfg, device) + self.hidden_size = cfg.hidden_size + + self._gain = cfg.gain + self._use_orthogonal = cfg.use_orthogonal + self._activation_id = cfg.activation_id + self._use_policy_active_masks = cfg.use_policy_active_masks + self._use_naive_recurrent_policy = cfg.use_naive_recurrent_policy + self._use_recurrent_policy = cfg.use_recurrent_policy + self._recurrent_N = cfg.recurrent_N + self.use_half = use_half + self.tpdv = dict(dtype=torch.float32, device=device) + + policy_obs_shape = get_policy_obs_space(input_space) + + if "Dict" in policy_obs_shape.__class__.__name__: + self._mixed_obs = True + self.base = MIXBase( + cfg, policy_obs_shape, cnn_layers_params=cfg.cnn_layers_params + ) + else: + self._mixed_obs = False + self.base = ( + CNNBase(cfg, policy_obs_shape) + if len(policy_obs_shape) == 3 + else MLPBase( + cfg, + policy_obs_shape, + use_attn_internal=cfg.use_attn_internal, + use_cat_self=True, + ) + ) + + input_size = self.base.output_size + + if self._use_naive_recurrent_policy or self._use_recurrent_policy: + self.rnn = RNNLayer( + input_size, + self.hidden_size, + self._recurrent_N, + self._use_orthogonal, + rnn_type=cfg.rnn_type, + ) + input_size = self.hidden_size + + self.act = ACTLayer(action_space, input_size, self._use_orthogonal, self._gain) + + if use_half: + self.half() + self.to(device) + + def forward( + self, raw_obs, rnn_states, masks, available_actions=None, deterministic=False + ): + policy_obs = get_policy_obs(raw_obs) + if self._mixed_obs: + for key in policy_obs.keys(): + policy_obs[key] = check(policy_obs[key], self.use_half, self.tpdv) + if self.use_half: + policy_obs[key].half() + else: + policy_obs = check(policy_obs, self.use_half, self.tpdv) + + rnn_states = check(rnn_states, self.use_half, self.tpdv) + masks = check(masks, self.use_half, self.tpdv) + + if available_actions is not None: + available_actions = check(available_actions, self.use_half, self.tpdv) + + actor_features = self.base(policy_obs) + + if self._use_naive_recurrent_policy or self._use_recurrent_policy: + actor_features, rnn_states = self.rnn(actor_features, rnn_states, masks) + + actions, action_log_probs = self.act( + actor_features, available_actions, deterministic + ) + return actions, action_log_probs, rnn_states + diff --git a/openrl/runners/common/__init__.py b/openrl/runners/common/__init__.py index 1f427013..fb88aad3 100644 --- a/openrl/runners/common/__init__.py +++ b/openrl/runners/common/__init__.py @@ -1,8 +1,11 @@ from openrl.runners.common.chat_agent import Chat6BAgent, ChatAgent from openrl.runners.common.ppo_agent import PPOAgent +from openrl.runners.common.dqn_agent import DQNAgent + __all__ = [ "PPOAgent", "ChatAgent", "Chat6BAgent", + "DQNAgent", ] diff --git a/openrl/runners/common/dqn_agent.py b/openrl/runners/common/dqn_agent.py new file mode 100644 index 00000000..fa11706a --- 
/dev/null +++ b/openrl/runners/common/dqn_agent.py @@ -0,0 +1,105 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright 2023 The OpenRL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""""" +from typing import Dict, Optional, Tuple, Union + +import gym +import numpy as np +import torch + +from openrl.algorithms.dqn import DQNAlgorithm as TrainAlgo +from openrl.buffers import NormalReplayBuffer as ReplayBuffer +from openrl.buffers.utils.obs_data import ObsData +from openrl.drivers.offpolicy_driver import OffPolicyDriver as Driver +from openrl.runners.common.rl_agent import RLAgent +from openrl.runners.common.base_agent import SelfAgent +from openrl.utils.logger import Logger +from openrl.utils.util import _t2n + + +class DQNAgent(RLAgent): + def __init__( + self, + net: Optional[torch.nn.Module] = None, + env: Union[gym.Env, str] = None, + run_dir: Optional[str] = None, + env_num: Optional[int] = None, + rank: int = 0, + world_size: int = 1, + use_wandb: bool = False, + use_tensorboard: bool = False, + ) -> None: + super(DQNAgent, self).__init__(net, env, run_dir, env_num, rank, world_size, use_wandb, use_tensorboard) + + def train(self: SelfAgent, total_time_steps: int) -> None: + self._cfg.num_env_steps = total_time_steps + + self.config = { + "cfg": self._cfg, + "num_agents": self.agent_num, + "run_dir": self.run_dir, + "envs": self._env, + "device": self.net.device, + } + + trainer = TrainAlgo( + cfg=self._cfg, + init_module=self.net.module, + device=self.net.device, + agent_num=self.agent_num, + ) + + buffer = ReplayBuffer( + self._cfg, + self.agent_num, + self._env.observation_space, + self._env.action_space, + data_client=None, + ) + + logger = Logger( + cfg=self._cfg, + project_name="DQNAgent", + scenario_name=self._env.env_name, + wandb_entity=self._cfg.wandb_entity, + exp_name=self.exp_name, + log_path=self.run_dir, + use_wandb=self._use_wandb, + use_tensorboard=self._use_tensorboard, + ) + driver = Driver( + config=self.config, + trainer=trainer, + buffer=buffer, + client=self.client, + rank=self.rank, + world_size=self.world_size, + logger=logger, + ) + driver.run() + + def act( + self, + observation: Union[np.ndarray, Dict[str, np.ndarray]] + ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]: + assert self.net is not None, "net is None" + observation = ObsData.prepare_input(observation) + action, rnn_state = self.net.act(observation) + + action = np.array(np.split(_t2n(action), self.env_num)) + + return action, rnn_state \ No newline at end of file diff --git a/openrl/runners/common/ppo_agent.py b/openrl/runners/common/ppo_agent.py index 35854252..f782bb21 100644 --- a/openrl/runners/common/ppo_agent.py +++ b/openrl/runners/common/ppo_agent.py @@ -15,17 +15,20 @@ # limitations under the License. 
"""""" -from typing import Optional, Union +from typing import Dict, Optional, Tuple, Union import gym +import numpy as np import torch from openrl.algorithms.ppo import PPOAlgorithm as TrainAlgo from openrl.buffers import NormalReplayBuffer as ReplayBuffer +from openrl.buffers.utils.obs_data import ObsData from openrl.drivers.onpolicy_driver import OnPolicyDriver as Driver from openrl.runners.common.base_agent import SelfAgent from openrl.runners.common.rl_agent import RLAgent from openrl.utils.logger import Logger +from openrl.utils.util import _t2n class PPOAgent(RLAgent): @@ -90,3 +93,16 @@ def train(self: SelfAgent, total_time_steps: int) -> None: logger=logger, ) driver.run() + + def act( + self, + observation: Union[np.ndarray, Dict[str, np.ndarray]], + deterministic: bool = True, + ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]: + assert self.net is not None, "net is None" + observation = ObsData.prepare_input(observation) + action, rnn_state = self.net.act(observation, deterministic=deterministic) + + action = np.array(np.split(_t2n(action), self.env_num)) + + return action, rnn_state diff --git a/openrl/runners/common/rl_agent.py b/openrl/runners/common/rl_agent.py index 934825c3..6b06f079 100644 --- a/openrl/runners/common/rl_agent.py +++ b/openrl/runners/common/rl_agent.py @@ -24,9 +24,7 @@ import numpy as np import torch -from openrl.buffers.utils.obs_data import ObsData from openrl.runners.common.base_agent import BaseAgent, SelfAgent -from openrl.utils.util import _t2n class RLAgent(BaseAgent): @@ -87,18 +85,12 @@ def __init__( def train(self: SelfAgent, total_time_steps: int) -> None: raise NotImplementedError + @abstractmethod def act( self, - observation: Union[np.ndarray, Dict[str, np.ndarray]], - deterministic: bool = True, - ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]: - assert self.net is not None, "net is None" - observation = ObsData.prepare_input(observation) - action, rnn_state = self.net.act(observation, deterministic=deterministic) - - action = np.array(np.split(_t2n(action), self.env_num)) - - return action, rnn_state + **kwargs + ) -> None: + raise NotImplementedError def set_env( self, diff --git a/tests/test_algorithm/test_dqn_algorithm.py b/tests/test_algorithm/test_dqn_algorithm.py new file mode 100644 index 00000000..7ede3282 --- /dev/null +++ b/tests/test_algorithm/test_dqn_algorithm.py @@ -0,0 +1,100 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright 2023 The OpenRL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""""" +import os +import sys + +import numpy as np +import pytest +from gymnasium import spaces + + +@pytest.fixture +def obs_space(): + return spaces.Box(low=-np.inf, high=+np.inf, shape=(1,), dtype=np.float32) + + +@pytest.fixture +def act_space(): + return spaces.Discrete(2) + + +@pytest.fixture(scope="module", params=[""]) +def config(request): + from openrl.configs.config import create_config_parser + + cfg_parser = create_config_parser() + cfg = cfg_parser.parse_args(request.param.split()) + return cfg + + +@pytest.fixture +def amp_config(): + from openrl.configs.config import create_config_parser + + cfg_parser = create_config_parser() + cfg = cfg_parser.parse_args("") + return cfg + + +@pytest.fixture +def init_module(config, obs_space, act_space): + from openrl.modules.dqn_module import DQNModule + + module = DQNModule( + config, + input_space=obs_space, + act_space=act_space, + ) + return module + + +@pytest.fixture +def buffer_data(config, obs_space, act_space): + from openrl.buffers.normal_buffer import NormalReplayBuffer + + buffer = NormalReplayBuffer( + config, + num_agents=1, + obs_space=obs_space, + act_space=act_space, + data_client=None, + episode_length=5000, + ) + return buffer.data + + +@pytest.mark.unittest +def test_dqn_algorithm(config, init_module, buffer_data): + from openrl.algorithms.dqn import DQNAlgorithm + + dqn_algo = DQNAlgorithm(config, init_module) + + # dqn_algo.train(buffer_data) + + +@pytest.mark.unittest +def test_dqn_algorithm_amp(config, init_module, buffer_data): + from openrl.algorithms.dqn import DQNAlgorithm + + dqn_algo = DQNAlgorithm(config, init_module) + + # dqn_algo.train(buffer_data) + + +if __name__ == "__main__": + sys.exit(pytest.main(["-sv", os.path.basename(__file__)]))