Skip to content

Commit

Permalink
polish(pu): polish qmix.py
Browse files Browse the repository at this point in the history
  • Loading branch information
puyuan1996 committed Nov 14, 2024
1 parent 5a0bdd8 commit 9285a84
Showing 1 changed file with 50 additions and 9 deletions.
59 changes: 50 additions & 9 deletions ding/model/template/qmix.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,14 +147,36 @@ def __init__(
embedding_size = hidden_size_list[-1]
self.mixer = mixer
if self.mixer:
if len(global_obs_shape) == 1:
global_obs_shape_type = self._get_global_obs_shape_type(global_obs_shape)

if global_obs_shape_type == "flat":
self._mixer = Mixer(agent_num, global_obs_shape, embedding_size, activation=activation)
self._global_state_encoder = nn.Identity()
elif len(global_obs_shape) == 3:
elif global_obs_shape_type == "image":
self._mixer = Mixer(agent_num, embedding_size, embedding_size, activation=activation)
self._global_state_encoder = ConvEncoder(global_obs_shape, hidden_size_list=hidden_size_list, activation=activation, norm_type='BN')
self._global_state_encoder = ConvEncoder(
global_obs_shape, hidden_size_list=hidden_size_list, activation=activation, norm_type='BN'
)
else:
raise ValueError("Not support global_obs_shape: {}".format(global_obs_shape))
raise ValueError(f"Unsupported global_obs_shape: {global_obs_shape}")

def _get_global_obs_shape_type(self, global_obs_shape: Union[int, List[int]]) -> str:
"""
Overview:
Determine the type of global observation shape.
Arguments:
- global_obs_shape (:obj:`int` or :obj:`List[int]`): The global observation state.
Returns:
- str: 'flat' for 1D observation or 'image' for 3D observation.
"""
if isinstance(global_obs_shape, int) or (isinstance(global_obs_shape, list) and len(global_obs_shape) == 1):
return "flat"
elif isinstance(global_obs_shape, list) and len(global_obs_shape) == 3:
return "image"
else:
raise ValueError(f"Unsupported global_obs_shape: {global_obs_shape}")

def forward(self, data: dict, single_step: bool = True) -> dict:
"""
Expand Down Expand Up @@ -214,18 +236,37 @@ def forward(self, data: dict, single_step: bool = True) -> dict:
agent_q_act = torch.gather(agent_q, dim=-1, index=action.unsqueeze(-1))
agent_q_act = agent_q_act.squeeze(-1) # T, B, A
if self.mixer:
if len(global_state.shape) == 5:
global_state_embedding = self._global_state_encoder(global_state.reshape(-1, *global_state.shape[-3:])).reshape(global_state.shape[0], global_state.shape[1], -1)
else:
global_state_embedding = self._global_state_encoder(global_state)
global_state_embedding = self._process_global_state(global_state)
total_q = self._mixer(agent_q_act, global_state_embedding)
else:
total_q = agent_q_act.sum(-1)
total_q = agent_q_act.sum(dim=-1)

if single_step:
total_q, agent_q = total_q.squeeze(0), agent_q.squeeze(0)

return {
'total_q': total_q,
'logit': agent_q,
'next_state': next_state,
'action_mask': data['obs']['action_mask']
}
def _process_global_state(self, global_state: torch.Tensor) -> torch.Tensor:
"""
Process the global state to obtain an embedding.
Arguments:
- global_state (:obj:`torch.Tensor`): The global state tensor.
Returns:
- (:obj:`torch.Tensor`): The processed global state embedding.
"""
# If global_state has 5 dimensions, it's likely in the form [batch_size, time_steps, C, H, W]
if global_state.dim() == 5:
# Reshape and apply the global state encoder
batch_time_shape = global_state.shape[:2] # [batch_size, time_steps]
reshaped_state = global_state.view(-1, *global_state.shape[-3:]) # Collapse batch and time dims
encoded_state = self._global_state_encoder(reshaped_state)
return encoded_state.view(*batch_time_shape, -1) # Reshape back to [batch_size, time_steps, embedding_dim]
else:
# For lower-dimensional states, apply the encoder directly
return self._global_state_encoder(global_state)

0 comments on commit 9285a84

Please sign in to comment.