diff --git a/ding/model/wrapper/model_wrappers.py b/ding/model/wrapper/model_wrappers.py index e427587327..94f5b86ac4 100644 --- a/ding/model/wrapper/model_wrappers.py +++ b/ding/model/wrapper/model_wrappers.py @@ -866,10 +866,14 @@ def forward(self, *args, **kwargs): assert isinstance(output, dict), "model output must be dict, but find {}".format(type(output)) if 'action' in output or 'action_args' in output: key = 'action' if 'action' in output else 'action_args' - action = output[key] + # handle hybrid action space by adding noise to continuous part of model output + action = output[key]['action_args'] if isinstance(output[key], dict) else output[key] assert isinstance(action, torch.Tensor) action = self.add_noise(action) - output[key] = action + if isinstance(output[key], dict): + output[key]['action_args'] = action + else: + output[key] = action return output def add_noise(self, action: torch.Tensor) -> torch.Tensor: