Skip to content

Commit

Permalink
polish(pu): polish qmix comments
Browse files Browse the repository at this point in the history
  • Loading branch information
puyuan1996 committed Nov 25, 2024
1 parent 4f08f91 commit 16da46e
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions ding/model/template/qmix.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,9 +167,9 @@ def _get_global_obs_shape_type(self, global_obs_shape: Union[int, List[int]]) ->
Overview:
Determine the type of global observation shape.
Arguments:
- global_obs_shape (Union[:obj:`int`, :obj:`List[int]`]): The global observation state.
- global_obs_shape (:obj:`Union[int, List[int]]`): The global observation state.
Returns:
- (:obj:`str`): 'flat' for 1D observation or 'image' for 3D observation.
- obs_shape_type (:obj:`str`): 'flat' for 1D observation or 'image' for 3D observation.
"""
if isinstance(global_obs_shape, int) or (isinstance(global_obs_shape, list) and len(global_obs_shape) == 1):
return "flat"
Expand Down Expand Up @@ -265,7 +265,7 @@ def _process_global_state(self, global_state: torch.Tensor) -> torch.Tensor:
- global_state (:obj:`torch.Tensor`): The global state tensor.
Returns:
- (:obj:`torch.Tensor`): The processed global state embedding.
- global_state_embedding (:obj:`torch.Tensor`): The processed global state embedding.
"""
# If global_state has 5 dimensions, it's likely in the form [batch_size, time_steps, C, H, W]
if global_state.dim() == 5:
Expand Down

0 comments on commit 16da46e

Please sign in to comment.