Skip to content

Commit

Permalink
Merge branch 'dev-ptz' of https://github.com/puyuan1996/DI-engine int…
Browse files Browse the repository at this point in the history
…o dev-ptz
  • Loading branch information
puyuan1996 committed Nov 24, 2024
2 parents 62b2f23 + f2d82ea commit c28c85a
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 33 deletions.
70 changes: 56 additions & 14 deletions ding/model/template/qmix.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from typing import Union, List
from functools import reduce
from typing import List, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import reduce
from ding.utils import list_split, MODEL_REGISTRY
from ding.torch_utils import fc_block, MLP
from ..common import ConvEncoder
from ding.torch_utils import MLP, fc_block
from ding.utils import MODEL_REGISTRY, list_split

from ..common import ConvEncoder
from .q_learning import DRQN


Expand Down Expand Up @@ -147,14 +149,34 @@ def __init__(
embedding_size = hidden_size_list[-1]
self.mixer = mixer
if self.mixer:
if len(global_obs_shape) == 1:
global_obs_shape_type = self._get_global_obs_shape_type(global_obs_shape)

if global_obs_shape_type == "flat":
self._mixer = Mixer(agent_num, global_obs_shape, embedding_size, activation=activation)
self._global_state_encoder = nn.Identity()
elif len(global_obs_shape) == 3:
elif global_obs_shape_type == "image":
self._mixer = Mixer(agent_num, embedding_size, embedding_size, activation=activation)
self._global_state_encoder = ConvEncoder(global_obs_shape, hidden_size_list=hidden_size_list, activation=activation, norm_type='BN')
self._global_state_encoder = ConvEncoder(
global_obs_shape, hidden_size_list=hidden_size_list, activation=activation, norm_type='BN'
)
else:
raise ValueError("Not support global_obs_shape: {}".format(global_obs_shape))
raise ValueError(f"Unsupported global_obs_shape: {global_obs_shape}")

def _get_global_obs_shape_type(self, global_obs_shape: Union[int, List[int]]) -> str:
"""
Overview:
Determine the type of global observation shape.
Arguments:
- global_obs_shape (:obj:`int` or :obj:`List[int]`): The global observation state.
Returns:
- str: 'flat' for 1D observation or 'image' for 3D observation.
"""
if isinstance(global_obs_shape, int) or (isinstance(global_obs_shape, list) and len(global_obs_shape) == 1):
return "flat"
elif isinstance(global_obs_shape, list) and len(global_obs_shape) == 3:
return "image"
else:
raise ValueError(f"Unsupported global_obs_shape: {global_obs_shape}")

def forward(self, data: dict, single_step: bool = True) -> dict:
"""
Expand Down Expand Up @@ -214,18 +236,38 @@ def forward(self, data: dict, single_step: bool = True) -> dict:
agent_q_act = torch.gather(agent_q, dim=-1, index=action.unsqueeze(-1))
agent_q_act = agent_q_act.squeeze(-1) # T, B, A
if self.mixer:
if len(global_state.shape) == 5:
global_state_embedding = self._global_state_encoder(global_state.reshape(-1, *global_state.shape[-3:])).reshape(global_state.shape[0], global_state.shape[1], -1)
else:
global_state_embedding = self._global_state_encoder(global_state)
global_state_embedding = self._process_global_state(global_state)
total_q = self._mixer(agent_q_act, global_state_embedding)
else:
total_q = agent_q_act.sum(-1)
total_q = agent_q_act.sum(dim=-1)

if single_step:
total_q, agent_q = total_q.squeeze(0), agent_q.squeeze(0)

return {
'total_q': total_q,
'logit': agent_q,
'next_state': next_state,
'action_mask': data['obs']['action_mask']
}

def _process_global_state(self, global_state: torch.Tensor) -> torch.Tensor:
"""
Overview:
Process the global state to obtain an embedding.
Arguments:
- global_state (:obj:`torch.Tensor`): The global state tensor.
Returns:
- (:obj:`torch.Tensor`): The processed global state embedding.
"""
# If global_state has 5 dimensions, it's likely in the form [batch_size, time_steps, C, H, W]
if global_state.dim() == 5:
# Reshape and apply the global state encoder
batch_time_shape = global_state.shape[:2] # [batch_size, time_steps]
reshaped_state = global_state.view(-1, *global_state.shape[-3:]) # Collapse batch and time dims
encoded_state = self._global_state_encoder(reshaped_state)
return encoded_state.view(*batch_time_shape, -1) # Reshape back to [batch_size, time_steps, embedding_dim]
else:
# For lower-dimensional states, apply the encoder directly
return self._global_state_encoder(global_state)
32 changes: 18 additions & 14 deletions dizoo/petting_zoo/config/ptz_pistonball_qmix_config.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
from easydict import EasyDict

# n_pistons = 20
n_pistons = 2
n_pistons = 20
collector_env_num = 8
evaluator_env_num = 8

main_config = dict(
exp_name='data_pistonball/ptz_pistonball_qmix_seed0',
exp_name=f'data_pistonball/ptz_pistonball_n{n_pistons}_qmix_seed0',
env=dict(
env_family='butterfly',
env_id='pistonball_v6',
Expand All @@ -28,30 +27,35 @@
obs_shape=(3, 457, 120), # RGB image observation shape for each piston agent
global_obs_shape=(3, 560, 880), # Global state shape
action_shape=3, # Discrete actions (0, 1, 2)
hidden_size_list=[128, 128, 64],
hidden_size_list=[32, 64, 128, 256],
mixer=True,
),
learn=dict(
update_per_collect=20,
batch_size=16,
batch_size=32,
learning_rate=0.0001,
clip_value=5,
target_update_theta=0.001,
discount_factor=0.99,
double_q=True,
clip_value=10,
),
collect=dict(
n_sample=32,
unroll_len=16,
n_sample=16,
unroll_len=5,
env_num=collector_env_num,
),
eval=dict(env_num=evaluator_env_num),
other=dict(eps=dict(
type='exp',
start=1.0,
end=0.05,
decay=100000,
)),
other=dict(
eps=dict(
type='exp',
start=1.0,
end=0.05,
decay=100000,
),
replay_buffer=dict(
replay_buffer_size=5000,
),
),
),
)
main_config = EasyDict(main_config)
Expand Down
32 changes: 27 additions & 5 deletions dizoo/petting_zoo/envs/petting_zoo_pistonball_env.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import copy
from functools import reduce
from typing import List, Optional, Dict
from typing import Dict, List, Optional

import gymnasium as gym
import numpy as np
Expand Down Expand Up @@ -31,6 +32,7 @@ def __init__(self, cfg: dict) -> None:
if self._act_scale:
assert self._continuous_actions, 'Action scaling only applies to continuous action spaces.'
self._channel_first = self._cfg.get('channel_first', True)
self.normalize_reward = self._cfg.normalize_reward

def reset(self) -> np.ndarray:
"""
Expand Down Expand Up @@ -127,11 +129,17 @@ def step(self, action: np.ndarray) -> BaseEnvTimestep:
obs_n = self._process_obs(obs)
rew_n = np.array([sum([rew[agent] for agent in self._agents])])
rew_n = rew_n.astype(np.float32)

if self.normalize_reward:
# TODO: more elegant scale factor
rew_n = rew_n / (self._num_pistons*50)

self._eval_episode_return += rew_n.item()

done_n = reduce(lambda x, y: x and y, done.values()) or self._step_count >= self._max_cycles
if done_n:
info['eval_episode_return'] = self._eval_episode_return


return BaseEnvTimestep(obs_n, rew_n, done_n, info)

Expand Down Expand Up @@ -202,9 +210,6 @@ def random_action(self) -> np.ndarray:
random_action[k] = to_ndarray([random_action[k]], dtype=np.int64)
return random_action

def __repr__(self) -> str:
return "DI-engine PettingZoo Pistonball Env"

@property
def agents(self) -> List[str]:
return self._agents
Expand All @@ -219,4 +224,21 @@ def action_space(self) -> gym.spaces.Space:

@property
def reward_space(self) -> gym.spaces.Space:
return self._reward_space
return self._reward_space

@staticmethod
def create_collector_env_cfg(cfg: dict) -> List[dict]:
collector_env_num = cfg.pop('collector_env_num')
cfg = copy.deepcopy(cfg)
cfg.normalize_reward = True
return [cfg for _ in range(collector_env_num)]

@staticmethod
def create_evaluator_env_cfg(cfg: dict) -> List[dict]:
evaluator_env_num = cfg.pop('evaluator_env_num')
cfg = copy.deepcopy(cfg)
cfg.normalize_reward = False
return [cfg for _ in range(evaluator_env_num)]

def __repr__(self) -> str:
return "DI-engine PettingZoo Pistonball Env"

0 comments on commit c28c85a

Please sign in to comment.