From 9285a84285ec47b262c45744e94f98edfdab1854 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E2=80=9Cpuyuan1996=E2=80=9D?= <2402552459@qq.com>
Date: Fri, 15 Nov 2024 01:09:28 +0800
Subject: [PATCH] polish(pu): polish qmix.py

Move the global observation shape check in __init__ into
_get_global_obs_shape_type, and the global state encoding in forward
into _process_global_state.

An int or a length-1 list/tuple is treated as a flat state and a
length-3 list/tuple as an image state; anything else raises ValueError.
_process_global_state folds the two leading dimensions of a 5-D state
with reshape (not view, which fails on non-contiguous tensors) before
calling the encoder, then restores them on the output.
---
 ding/model/template/qmix.py | 59 +++++++++++++++++++++++++++++++------
 1 file changed, 50 insertions(+), 9 deletions(-)

diff --git a/ding/model/template/qmix.py b/ding/model/template/qmix.py
index 6d28bac6bc..05ccb4e8e5 100644
--- a/ding/model/template/qmix.py
+++ b/ding/model/template/qmix.py
@@ -147,14 +147,36 @@ def __init__(
         embedding_size = hidden_size_list[-1]
         self.mixer = mixer
         if self.mixer:
-            if len(global_obs_shape) == 1:
+            global_obs_shape_type = self._get_global_obs_shape_type(global_obs_shape)
+
+            if global_obs_shape_type == "flat":
                 self._mixer = Mixer(agent_num, global_obs_shape, embedding_size, activation=activation)
                 self._global_state_encoder = nn.Identity()
-            elif len(global_obs_shape) == 3:
+            elif global_obs_shape_type == "image":
                 self._mixer = Mixer(agent_num, embedding_size, embedding_size, activation=activation)
-                self._global_state_encoder = ConvEncoder(global_obs_shape, hidden_size_list=hidden_size_list, activation=activation, norm_type='BN')
+                self._global_state_encoder = ConvEncoder(
+                    global_obs_shape, hidden_size_list=hidden_size_list, activation=activation, norm_type='BN'
+                )
             else:
-                raise ValueError("Not support global_obs_shape: {}".format(global_obs_shape))
+                raise ValueError(f"Unsupported global_obs_shape: {global_obs_shape}")
+
+    def _get_global_obs_shape_type(self, global_obs_shape: Union[int, List[int]]) -> str:
+        """
+        Overview:
+            Classify ``global_obs_shape`` as a flat vector shape or an image shape.
+        Arguments:
+            - global_obs_shape (:obj:`Union[int, List[int]]`): The global observation shape. Tuples are accepted too.
+        Returns:
+            - shape_type (:obj:`str`): ``'flat'`` for an int or a length-1 shape, ``'image'`` for a length-3 shape.
+        Raises:
+            - ValueError: If ``global_obs_shape`` is neither an int nor a list/tuple of length 1 or 3.
+        """
+        is_sequence = isinstance(global_obs_shape, (list, tuple))  # the replaced len() check accepted both
+        if isinstance(global_obs_shape, int) or (is_sequence and len(global_obs_shape) == 1):
+            return "flat"
+        if is_sequence and len(global_obs_shape) == 3:
+            return "image"
+        raise ValueError(f"Unsupported global_obs_shape: {global_obs_shape}")
 
     def forward(self, data: dict, single_step: bool = True) -> dict:
         """
@@ -214,18 +236,37 @@ def forward(self, data: dict, single_step: bool = True) -> dict:
         agent_q_act = torch.gather(agent_q, dim=-1, index=action.unsqueeze(-1))
         agent_q_act = agent_q_act.squeeze(-1)  # T, B, A
         if self.mixer:
-            if len(global_state.shape) == 5:
-                global_state_embedding = self._global_state_encoder(global_state.reshape(-1, *global_state.shape[-3:])).reshape(global_state.shape[0], global_state.shape[1], -1)
-            else:
-                global_state_embedding = self._global_state_encoder(global_state)
+            global_state_embedding = self._process_global_state(global_state)
             total_q = self._mixer(agent_q_act, global_state_embedding)
         else:
-            total_q = agent_q_act.sum(-1)
+            total_q = agent_q_act.sum(dim=-1)
+
         if single_step:
             total_q, agent_q = total_q.squeeze(0), agent_q.squeeze(0)
+
         return {
             'total_q': total_q,
             'logit': agent_q,
             'next_state': next_state,
             'action_mask': data['obs']['action_mask']
         }
+
+    def _process_global_state(self, global_state: torch.Tensor) -> torch.Tensor:
+        """
+        Overview:
+            Encode the global state with ``self._global_state_encoder``. A 5-D state has its two leading
+            dimensions folded into one for the encoder call and restored on the output.
+        Arguments:
+            - global_state (:obj:`torch.Tensor`): The global state tensor.
+        Returns:
+            - global_state_embedding (:obj:`torch.Tensor`): The encoder output for ``global_state``.
+        """
+        if global_state.dim() != 5:
+            # Lower-dimensional states go through the encoder unchanged.
+            return self._global_state_encoder(global_state)
+        # Presumably (T, B, C, H, W), matching the ``T, B, A`` layout noted in ``forward`` -- confirm with callers.
+        leading_shape = global_state.shape[:2]
+        # ``reshape`` rather than ``view``: ``view`` raises on a non-contiguous input, ``reshape`` copies if needed.
+        folded_state = global_state.reshape(-1, *global_state.shape[-3:])
+        folded_embedding = self._global_state_encoder(folded_state)
+        return folded_embedding.reshape(*leading_shape, -1)
\ No newline at end of file