From c7b83f8aa8ff597045762b76fa8a63065165f609 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Mon, 8 Jul 2024 13:56:11 +0200 Subject: [PATCH] Improve interface of BasePolicy.compute_action #1169 --- tianshou/policy/base.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index b7ae5f23d..6b32637fd 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -10,6 +10,7 @@ import torch from gymnasium.spaces import Box, Discrete, MultiBinary, MultiDiscrete from numba import njit +from numpy._typing import ArrayLike from overrides import override from torch import nn @@ -289,7 +290,7 @@ def soft_update(self, tgt: nn.Module, src: nn.Module, tau: float) -> None: def compute_action( self, - obs: arr_type, + obs: ArrayLike, info: dict[str, Any] | None = None, state: dict | BatchProtocol | np.ndarray | None = None, ) -> np.ndarray | int: @@ -300,8 +301,8 @@ def compute_action( :param state: the hidden state of RNN policy, used for recurrent policy. :return: action as int (for discrete env's) or array (for continuous ones). """ - # need to add empty batch dimension - obs = obs[None, :] + obs = np.array(obs) # convert array-like to array (e.g. LazyFrames) + obs = obs[None, :] # need to add empty batch dimension obs_batch = cast(ObsBatchProtocol, Batch(obs=obs, info=info)) act = self.forward(obs_batch, state=state).act.squeeze() if isinstance(act, torch.Tensor):