optimize vdn algorithm v3 #108

Merged
merged 1 commit on Jun 8, 2023
3 changes: 2 additions & 1 deletion examples/mpe/mpe_vdn.yaml
@@ -1,6 +1,7 @@
seed: 0
lr: 7e-4
episode_length: 25
episode_length: 200
num_mini_batch: 128
run_dir: ./run_results/
experiment_name: train_mpe_vdn
log_interval: 10
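Note on this hunk: MPE rollouts are lengthened from 25 to 200 steps and a num_mini_batch setting is introduced for the VDN trainer. A quick sanity check of the resulting file might look like the sketch below (an illustration only; it assumes PyYAML is installed and uses the repo-relative path from the diff header).

    # Hypothetical check of the updated config (assumes PyYAML and the path above).
    import yaml

    with open("examples/mpe/mpe_vdn.yaml") as f:
        cfg = yaml.safe_load(f)

    assert cfg["episode_length"] == 200   # rollouts lengthened from 25 steps
    assert cfg["num_mini_batch"] == 128   # new key; the algorithm's train() uses self.num_mini_batch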
2 changes: 1 addition & 1 deletion examples/mpe/train_vdn.py
@@ -29,7 +29,7 @@ def train():
# start training
agent.train(total_time_steps=5000000)
env.close()
agent.save("./mat_agent/")
agent.save("./vdn_agent/")
return agent


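The script change above only corrects the save directory from the leftover ./mat_agent/ name to ./vdn_agent/. With episode_length raised to 200 in the config, the fixed budget of 5,000,000 time steps in train() corresponds to fewer but longer episodes. Rough arithmetic (ignoring parallel-environment and multi-agent step accounting, which this diff does not show):

    # Back-of-the-envelope episode counts for the 5,000,000-step budget.
    total_time_steps = 5_000_000
    print(total_time_steps // 25)    # 200000 episodes at the old episode_length
    print(total_time_steps // 200)   # 25000 episodes at the new episode_length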
280 changes: 14 additions & 266 deletions openrl/algorithms/vdn.py
@@ -41,8 +41,10 @@ def __init__(

self.gamma = cfg.gamma
self.n_agent = cfg.num_agents
self.update_count = 0
self.target_update_frequency = cfg.train_interval

def dqn_update(self, sample, turn_on=True):
def vdn_update(self, sample, turn_on=True):
for optimizer in self.algo_module.optimizers.values():
optimizer.zero_grad()

@@ -120,6 +122,12 @@ def dqn_update(self, sample, turn_on=True):
if self.world_size > 1:
torch.cuda.synchronize()

if self.update_count % self.target_update_frequency == 0:
self.update_count = 0
self.algo_module.models["target_vdn_net"].load_state_dict(
self.algo_module.models["vdn_net"].state_dict()
)

return loss

def cal_value_loss(
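The hunk above moves the hard target-network sync into the per-update path, gated by update_count and target_update_frequency (taken from cfg.train_interval in __init__). A minimal, self-contained sketch of that pattern, using illustrative names rather than the repo's module layout, and assuming update_count is incremented once per gradient step:

    # Standalone sketch of periodic hard target-network updates (illustrative only).
    import torch.nn as nn

    online = nn.Linear(8, 4)
    target = nn.Linear(8, 4)
    target.load_state_dict(online.state_dict())

    target_update_frequency = 100   # plays the role of cfg.train_interval
    update_count = 0

    def after_gradient_step():
        global update_count
        update_count += 1
        if update_count % target_update_frequency == 0:
            update_count = 0
            # Hard copy: the target lags the online net by at most target_update_frequency updates.
            target.load_state_dict(online.state_dict())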
@@ -198,9 +206,11 @@ def prepare_loss(
)

rewards_batch = rewards_batch.reshape(-1, self.n_agent, 1)
rewards_batch = torch.sum(rewards_batch, dim=2, keepdim=True).view(-1, 1)
rewards_batch = torch.sum(rewards_batch, dim=1, keepdim=True).view(-1, 1)
q_targets = rewards_batch + self.gamma * max_next_q_values
q_loss = torch.mean(F.mse_loss(q_values, q_targets.detach())) # mean squared error loss
q_loss = torch.mean(
F.mse_loss(q_values, q_targets.detach())
) # mean squared error loss

loss_list.append(q_loss)
return loss_list
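The dim change in the rewards sum is the substantive fix here: rewards_batch is reshaped to (-1, n_agent, 1), so summing over dim=2 reduced a size-1 axis (a no-op that kept per-agent rewards), while summing over dim=1 yields the team reward that the summed VDN Q-value is regressed toward. A small shape check (shapes chosen for illustration):

    # Illustrative shape check for the reward-summation fix.
    import torch

    batch, n_agent = 4, 3
    rewards = torch.ones(batch, n_agent, 1)

    per_agent = torch.sum(rewards, dim=2, keepdim=True).view(-1, 1)  # old: shape (12, 1), values 1.0
    team = torch.sum(rewards, dim=1, keepdim=True).view(-1, 1)       # new: shape (4, 1), values 3.0

    print(per_agent.shape, team.shape)  # torch.Size([12, 1]) torch.Size([4, 1])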
@@ -225,271 +235,10 @@ def train(self, buffer, turn_on=True):
data_generator = buffer.feed_forward_generator(
None,
num_mini_batch=self.num_mini_batch,
# mini_batch_size=self.mini_batch_size,
)

for sample in data_generator:
(q_loss) = self.dqn_update(sample, turn_on)
print(q_loss)
if self.world_size > 1:
train_info["reduced_q_loss"] += reduce_tensor(
q_loss.data, self.world_size
)

train_info["q_loss"] += q_loss.item()

self.algo_module.models["target_vdn_net"].load_state_dict(
self.algo_module.models["vdn_net"].state_dict()
)
num_updates = 1 * self.num_mini_batch

for k in train_info.keys():
train_info[k] /= num_updates

for optimizer in self.algo_module.optimizers.values():
if hasattr(optimizer, "sync_lookahead"):
optimizer.sync_lookahead()

return train_info


'''

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2023 The OpenRL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

""""""

from typing import Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from openrl.algorithms.base_algorithm import BaseAlgorithm
from openrl.modules.networks.utils.distributed_utils import reduce_tensor
from openrl.modules.utils.util import get_gard_norm, huber_loss, mse_loss
from openrl.utils.util import check


class VDNAlgorithm(BaseAlgorithm):
def __init__(
self,
cfg,
init_module,
agent_num: int = 1,
device: Union[str, torch.device] = "cpu",
) -> None:
super(VDNAlgorithm, self).__init__(cfg, init_module, agent_num, device)

self.gamma = cfg.gamma
self.n_agent = cfg.num_agents
self.parallel_env_num = cfg.parallel_env_num

def dqn_update(self, sample, turn_on=True):
for optimizer in self.algo_module.optimizers.values():
optimizer.zero_grad()

(
obs_batch,
_,
next_obs_batch,
_,
rnn_states_batch,
rnn_states_critic_batch,
actions_batch,
value_preds_batch,
rewards_batch,
masks_batch,
active_masks_batch,
old_action_log_probs_batch,
adv_targ,
available_actions_batch,
) = sample

value_preds_batch = check(value_preds_batch).to(**self.tpdv)
rewards_batch = check(rewards_batch).to(**self.tpdv)
active_masks_batch = check(active_masks_batch).to(**self.tpdv)

if self.use_amp:
with torch.cuda.amp.autocast():
loss_list = self.prepare_loss(
obs_batch,
next_obs_batch,
rnn_states_batch,
actions_batch,
masks_batch,
available_actions_batch,
value_preds_batch,
rewards_batch,
active_masks_batch,
turn_on,
)
for loss in loss_list:
self.algo_module.scaler.scale(loss).backward()
else:
loss_list = self.prepare_loss(
obs_batch,
next_obs_batch,
rnn_states_batch,
actions_batch,
masks_batch,
available_actions_batch,
value_preds_batch,
rewards_batch,
active_masks_batch,
turn_on,
)
for loss in loss_list:
loss.backward()

if "transformer" in self.algo_module.models:
raise NotImplementedError
else:
actor_para = self.algo_module.models["vdn_net"].parameters()
actor_grad_norm = get_gard_norm(actor_para)

if self.use_amp:
for optimizer in self.algo_module.optimizers.values():
self.algo_module.scaler.unscale_(optimizer)

for optimizer in self.algo_module.optimizers.values():
self.algo_module.scaler.step(optimizer)

self.algo_module.scaler.update()
else:
for optimizer in self.algo_module.optimizers.values():
optimizer.step()

if self.world_size > 1:
torch.cuda.synchronize()

return loss

def cal_value_loss(
self,
value_normalizer,
values,
value_preds_batch,
return_batch,
active_masks_batch,
):
value_pred_clipped = value_preds_batch + (values - value_preds_batch).clamp(
-self.clip_param, self.clip_param
)

if self._use_popart or self._use_valuenorm:
value_normalizer.update(return_batch)
error_clipped = (
value_normalizer.normalize(return_batch) - value_pred_clipped
)
error_original = value_normalizer.normalize(return_batch) - values
else:
error_clipped = return_batch - value_pred_clipped
error_original = return_batch - values

if self._use_huber_loss:
value_loss_clipped = huber_loss(error_clipped, self.huber_delta)
value_loss_original = huber_loss(error_original, self.huber_delta)
else:
value_loss_clipped = mse_loss(error_clipped)
value_loss_original = mse_loss(error_original)

if self._use_clipped_value_loss:
value_loss = torch.max(value_loss_original, value_loss_clipped)
else:
value_loss = value_loss_original

if self._use_value_active_masks:
value_loss = (
value_loss * active_masks_batch
).sum() / active_masks_batch.sum()
else:
value_loss = value_loss.mean()

return value_loss

def to_single_np(self, input):
reshape_input = input.reshape(-1, self.agent_num, *input.shape[1:])
return reshape_input[:, 0, ...]

def prepare_loss(
self,
obs_batch,
next_obs_batch,
rnn_states_batch,
actions_batch,
masks_batch,
available_actions_batch,
value_preds_batch,
rewards_batch,
active_masks_batch,
turn_on,
):
loss_list = []
critic_masks_batch = masks_batch

(q_values, max_next_q_values) = self.algo_module.evaluate_actions(
obs_batch,
next_obs_batch,
rnn_states_batch,
rewards_batch,
actions_batch,
masks_batch,
available_actions_batch,
active_masks_batch,
critic_masks_batch=critic_masks_batch,
)

rewards_batch = rewards_batch.reshape(
-1, self.parallel_env_num, self.n_agent, 1
)
rewards_batch = torch.sum(rewards_batch, dim=2, keepdim=True).view(-1, 1)
q_targets = rewards_batch + self.gamma * max_next_q_values
q_loss = torch.mean(F.mse_loss(q_values, q_targets)) # mean squared error loss

loss_list.append(q_loss)
return loss_list

def train(self, buffer, turn_on=True):
train_info = {}

train_info["q_loss"] = 0

if self.world_size > 1:
train_info["reduced_q_loss"] = 0

# todo add rnn and transformer
# update once
for _ in range(1):
if "transformer" in self.algo_module.models:
raise NotImplementedError
elif self._use_recurrent_policy:
raise NotImplementedError
elif self._use_naive_recurrent:
raise NotImplementedError
else:
data_generator = buffer.feed_forward_generator(
None, self.num_mini_batch
)

for sample in data_generator:
(q_loss) = self.dqn_update(sample, turn_on)

(q_loss) = self.vdn_update(sample, turn_on)
if self.world_size > 1:
train_info["reduced_q_loss"] += reduce_tensor(
q_loss.data, self.world_size
@@ -507,4 +256,3 @@ def train(self, buffer, turn_on=True):
optimizer.sync_lookahead()

return train_info
'''
2 changes: 1 addition & 1 deletion openrl/runners/common/vdn_agent.py
@@ -76,7 +76,7 @@ def train(self: SelfAgent, total_time_steps: int) -> None:

logger = Logger(
cfg=self._cfg,
project_name="DQNAgent",
project_name="VDNAgent",
scenario_name=self._env.env_name,
wandb_entity=self._cfg.wandb_entity,
exp_name=self.exp_name,