diff --git a/ding/model/template/qmix.py b/ding/model/template/qmix.py index d2a4eefc92..b5cde0806b 100644 --- a/ding/model/template/qmix.py +++ b/ding/model/template/qmix.py @@ -167,9 +167,9 @@ def _get_global_obs_shape_type(self, global_obs_shape: Union[int, List[int]]) -> Overview: Determine the type of global observation shape. Arguments: - - global_obs_shape (Union[:obj:`int`, :obj:`List[int]`]): The global observation state. + - global_obs_shape (:obj:`Union[int, List[int]]`): The global observation state. Returns: - - (:obj:`str`): 'flat' for 1D observation or 'image' for 3D observation. + - obs_shape_type (:obj:`str`): 'flat' for 1D observation or 'image' for 3D observation. """ if isinstance(global_obs_shape, int) or (isinstance(global_obs_shape, list) and len(global_obs_shape) == 1): return "flat" @@ -265,7 +265,7 @@ def _process_global_state(self, global_state: torch.Tensor) -> torch.Tensor: - global_state (:obj:`torch.Tensor`): The global state tensor. Returns: - - (:obj:`torch.Tensor`): The processed global state embedding. + - global_state_embedding (:obj:`torch.Tensor`): The processed global state embedding. """ # If global_state has 5 dimensions, it's likely in the form [batch_size, time_steps, C, H, W] if global_state.dim() == 5: