diff --git a/examples/mpe/mpe_vdn.yaml b/examples/mpe/mpe_vdn.yaml
index 08961b77..e17be7e4 100644
--- a/examples/mpe/mpe_vdn.yaml
+++ b/examples/mpe/mpe_vdn.yaml
@@ -1,6 +1,7 @@
 seed: 0
 lr: 7e-4
-episode_length: 25
+episode_length: 200
+num_mini_batch: 128
 run_dir: ./run_results/
 experiment_name: train_mpe_vdn
 log_interval: 10
diff --git a/examples/mpe/train_vdn.py b/examples/mpe/train_vdn.py
index a96b4c31..53b44e54 100644
--- a/examples/mpe/train_vdn.py
+++ b/examples/mpe/train_vdn.py
@@ -29,7 +29,7 @@ def train():
     # start training
     agent.train(total_time_steps=5000000)
     env.close()
-    agent.save("./mat_agent/")
+    agent.save("./vdn_agent/")
     return agent
 
 
diff --git a/openrl/algorithms/vdn.py b/openrl/algorithms/vdn.py
index 7930e08e..cff0c542 100644
--- a/openrl/algorithms/vdn.py
+++ b/openrl/algorithms/vdn.py
@@ -41,8 +41,10 @@ def __init__(
 
         self.gamma = cfg.gamma
         self.n_agent = cfg.num_agents
+        self.update_count = 0
+        self.target_update_frequency = cfg.train_interval
 
-    def dqn_update(self, sample, turn_on=True):
+    def vdn_update(self, sample, turn_on=True):
         for optimizer in self.algo_module.optimizers.values():
             optimizer.zero_grad()
@@ -120,6 +122,13 @@ def dqn_update(self, sample, turn_on=True):
         if self.world_size > 1:
             torch.cuda.synchronize()
 
+        self.update_count += 1
+        if self.update_count % self.target_update_frequency == 0:
+            self.update_count = 0
+            self.algo_module.models["target_vdn_net"].load_state_dict(
+                self.algo_module.models["vdn_net"].state_dict()
+            )
+
         return loss
 
     def cal_value_loss(
@@ -198,9 +207,11 @@ def prepare_loss(
         )
 
         rewards_batch = rewards_batch.reshape(-1, self.n_agent, 1)
-        rewards_batch = torch.sum(rewards_batch, dim=2, keepdim=True).view(-1, 1)
+        rewards_batch = torch.sum(rewards_batch, dim=1, keepdim=True).view(-1, 1)
         q_targets = rewards_batch + self.gamma * max_next_q_values
-        q_loss = torch.mean(F.mse_loss(q_values, q_targets.detach()))  # mean squared error loss
+        q_loss = torch.mean(
+            F.mse_loss(q_values, q_targets.detach())
+        )  # mean squared error loss
 
         loss_list.append(q_loss)
         return loss_list
@@ -225,271 +236,10 @@ def train(self, buffer, turn_on=True):
                 data_generator = buffer.feed_forward_generator(
                     None,
                     num_mini_batch=self.num_mini_batch,
-                    # mini_batch_size=self.mini_batch_size,
-                )
-
-            for sample in data_generator:
-                (q_loss) = self.dqn_update(sample, turn_on)
-                print(q_loss)
-                if self.world_size > 1:
-                    train_info["reduced_q_loss"] += reduce_tensor(
-                        q_loss.data, self.world_size
-                    )
-
-                train_info["q_loss"] += q_loss.item()
-
-            self.algo_module.models["target_vdn_net"].load_state_dict(
-                self.algo_module.models["vdn_net"].state_dict()
-            )
-        num_updates = 1 * self.num_mini_batch
-
-        for k in train_info.keys():
-            train_info[k] /= num_updates
-
-        for optimizer in self.algo_module.optimizers.values():
-            if hasattr(optimizer, "sync_lookahead"):
-                optimizer.sync_lookahead()
-
-        return train_info
-
-
-'''
-
-#!/usr/bin/env python
-# -*- coding: utf-8 -*-
-# Copyright 2023 The OpenRL Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
- -"""""" - -from typing import Union - -import numpy as np -import torch -import torch.nn as nn -import torch.nn.functional as F - -from openrl.algorithms.base_algorithm import BaseAlgorithm -from openrl.modules.networks.utils.distributed_utils import reduce_tensor -from openrl.modules.utils.util import get_gard_norm, huber_loss, mse_loss -from openrl.utils.util import check - - -class VDNAlgorithm(BaseAlgorithm): - def __init__( - self, - cfg, - init_module, - agent_num: int = 1, - device: Union[str, torch.device] = "cpu", - ) -> None: - super(VDNAlgorithm, self).__init__(cfg, init_module, agent_num, device) - - self.gamma = cfg.gamma - self.n_agent = cfg.num_agents - self.parallel_env_num = cfg.parallel_env_num - - def dqn_update(self, sample, turn_on=True): - for optimizer in self.algo_module.optimizers.values(): - optimizer.zero_grad() - - ( - obs_batch, - _, - next_obs_batch, - _, - rnn_states_batch, - rnn_states_critic_batch, - actions_batch, - value_preds_batch, - rewards_batch, - masks_batch, - active_masks_batch, - old_action_log_probs_batch, - adv_targ, - available_actions_batch, - ) = sample - - value_preds_batch = check(value_preds_batch).to(**self.tpdv) - rewards_batch = check(rewards_batch).to(**self.tpdv) - active_masks_batch = check(active_masks_batch).to(**self.tpdv) - - if self.use_amp: - with torch.cuda.amp.autocast(): - loss_list = self.prepare_loss( - obs_batch, - next_obs_batch, - rnn_states_batch, - actions_batch, - masks_batch, - available_actions_batch, - value_preds_batch, - rewards_batch, - active_masks_batch, - turn_on, - ) - for loss in loss_list: - self.algo_module.scaler.scale(loss).backward() - else: - loss_list = self.prepare_loss( - obs_batch, - next_obs_batch, - rnn_states_batch, - actions_batch, - masks_batch, - available_actions_batch, - value_preds_batch, - rewards_batch, - active_masks_batch, - turn_on, - ) - for loss in loss_list: - loss.backward() - - if "transformer" in self.algo_module.models: - raise NotImplementedError - else: - actor_para = self.algo_module.models["vdn_net"].parameters() - actor_grad_norm = get_gard_norm(actor_para) - - if self.use_amp: - for optimizer in self.algo_module.optimizers.values(): - self.algo_module.scaler.unscale_(optimizer) - - for optimizer in self.algo_module.optimizers.values(): - self.algo_module.scaler.step(optimizer) - - self.algo_module.scaler.update() - else: - for optimizer in self.algo_module.optimizers.values(): - optimizer.step() - - if self.world_size > 1: - torch.cuda.synchronize() - - return loss - - def cal_value_loss( - self, - value_normalizer, - values, - value_preds_batch, - return_batch, - active_masks_batch, - ): - value_pred_clipped = value_preds_batch + (values - value_preds_batch).clamp( - -self.clip_param, self.clip_param - ) - - if self._use_popart or self._use_valuenorm: - value_normalizer.update(return_batch) - error_clipped = ( - value_normalizer.normalize(return_batch) - value_pred_clipped - ) - error_original = value_normalizer.normalize(return_batch) - values - else: - error_clipped = return_batch - value_pred_clipped - error_original = return_batch - values - - if self._use_huber_loss: - value_loss_clipped = huber_loss(error_clipped, self.huber_delta) - value_loss_original = huber_loss(error_original, self.huber_delta) - else: - value_loss_clipped = mse_loss(error_clipped) - value_loss_original = mse_loss(error_original) - - if self._use_clipped_value_loss: - value_loss = torch.max(value_loss_original, value_loss_clipped) - else: - value_loss = value_loss_original - - if 
-            value_loss = (
-                value_loss * active_masks_batch
-            ).sum() / active_masks_batch.sum()
-        else:
-            value_loss = value_loss.mean()
-
-        return value_loss
-
-    def to_single_np(self, input):
-        reshape_input = input.reshape(-1, self.agent_num, *input.shape[1:])
-        return reshape_input[:, 0, ...]
-
-    def prepare_loss(
-        self,
-        obs_batch,
-        next_obs_batch,
-        rnn_states_batch,
-        actions_batch,
-        masks_batch,
-        available_actions_batch,
-        value_preds_batch,
-        rewards_batch,
-        active_masks_batch,
-        turn_on,
-    ):
-        loss_list = []
-        critic_masks_batch = masks_batch
-
-        (q_values, max_next_q_values) = self.algo_module.evaluate_actions(
-            obs_batch,
-            next_obs_batch,
-            rnn_states_batch,
-            rewards_batch,
-            actions_batch,
-            masks_batch,
-            available_actions_batch,
-            active_masks_batch,
-            critic_masks_batch=critic_masks_batch,
-        )
-
-        rewards_batch = rewards_batch.reshape(
-            -1, self.parallel_env_num, self.n_agent, 1
-        )
-        rewards_batch = torch.sum(rewards_batch, dim=2, keepdim=True).view(-1, 1)
-        q_targets = rewards_batch + self.gamma * max_next_q_values
-        q_loss = torch.mean(F.mse_loss(q_values, q_targets))  # mean squared error loss
-
-        loss_list.append(q_loss)
-        return loss_list
-
-    def train(self, buffer, turn_on=True):
-        train_info = {}
-
-        train_info["q_loss"] = 0
-
-        if self.world_size > 1:
-            train_info["reduced_q_loss"] = 0
-
-        # todo add rnn and transformer
-        # update once
-        for _ in range(1):
-            if "transformer" in self.algo_module.models:
-                raise NotImplementedError
-            elif self._use_recurrent_policy:
-                raise NotImplementedError
-            elif self._use_naive_recurrent:
-                raise NotImplementedError
-            else:
-                data_generator = buffer.feed_forward_generator(
-                    None, self.num_mini_batch
                 )
 
         for sample in data_generator:
-            (q_loss) = self.dqn_update(sample, turn_on)
-
+            q_loss = self.vdn_update(sample, turn_on)
             if self.world_size > 1:
                 train_info["reduced_q_loss"] += reduce_tensor(
                     q_loss.data, self.world_size
@@ -507,4 +257,3 @@ def train(self, buffer, turn_on=True):
             optimizer.sync_lookahead()
 
         return train_info
-'''
diff --git a/openrl/runners/common/vdn_agent.py b/openrl/runners/common/vdn_agent.py
index 94594c8e..c37825fe 100644
--- a/openrl/runners/common/vdn_agent.py
+++ b/openrl/runners/common/vdn_agent.py
@@ -76,7 +76,7 @@ def train(self: SelfAgent, total_time_steps: int) -> None:
 
         logger = Logger(
             cfg=self._cfg,
-            project_name="DQNAgent",
+            project_name="VDNAgent",
             scenario_name=self._env.env_name,
             wandb_entity=self._cfg.wandb_entity,
             exp_name=self.exp_name,
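For reference, below is a minimal self-contained sketch of the two ideas the vdn.py changes rely on: a counter-gated hard sync of the target network, and summing per-agent rewards over the agent axis (dim=1) to form the team reward that VDN's summed Q-value regresses toward. It is illustrative only; QNet, sync_every, and the tensor shapes are assumptions made for this example, not OpenRL's actual API.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class QNet(nn.Module):
        """Tiny per-agent Q-network (hypothetical, for illustration)."""

        def __init__(self, obs_dim: int, n_actions: int):
            super().__init__()
            self.net = nn.Sequential(
                nn.Linear(obs_dim, 64), nn.ReLU(), nn.Linear(64, n_actions)
            )

        def forward(self, obs):
            return self.net(obs)

    q_net, target_net = QNet(8, 5), QNet(8, 5)
    target_net.load_state_dict(q_net.state_dict())
    update_count, sync_every = 0, 100  # sync_every stands in for cfg.train_interval

    def vdn_loss(obs, next_obs, actions, rewards, gamma=0.99):
        # obs, next_obs: (batch, n_agent, obs_dim); actions: (batch, n_agent, 1), int64
        # rewards: (batch, n_agent, 1)
        q_all = q_net(obs)                       # (batch, n_agent, n_actions)
        q_taken = q_all.gather(-1, actions)      # per-agent Q of the chosen actions
        q_tot = q_taken.sum(dim=1)               # VDN: Q_tot = sum_i Q_i
        with torch.no_grad():
            next_max = target_net(next_obs).max(-1, keepdim=True).values
            team_reward = rewards.sum(dim=1)     # sum over the agent axis (dim=1)
            q_target = team_reward + gamma * next_max.sum(dim=1)
        return F.mse_loss(q_tot, q_target)

    def after_each_update():
        # Counter-gated hard sync, mirroring the logic added to vdn_update().
        global update_count
        update_count += 1
        if update_count % sync_every == 0:
            update_count = 0
            target_net.load_state_dict(q_net.state_dict())

Training would call vdn_loss once per minibatch, step the optimizer, then call after_each_update; keeping the target weights fixed for sync_every updates between copies stabilizes the TD regression target.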