From 0228e51e8fab3a02c4e841cbbb09b7dc2c1b4e02 Mon Sep 17 00:00:00 2001
From: huangshiyu
Date: Mon, 23 Oct 2023 17:15:01 +0800
Subject: [PATCH] fix arena PettingZoo import error; format code

---
 openrl/algorithms/dqn.py                      |  4 ++-
 openrl/algorithms/vdn.py                      |  4 ++-
 openrl/arena/__init__.py                      |  4 ++-
 openrl/envs/mpe/rendering.py                  | 10 +++---
 openrl/envs/snake/snake.py                    |  4 ++-
 openrl/envs/vec_env/async_venv.py             | 34 ++++++-------------
 openrl/utils/callbacks/checkpoint_callback.py |  4 +--
 openrl/utils/evaluation.py                    | 10 +++---
 8 files changed, 32 insertions(+), 42 deletions(-)

diff --git a/openrl/algorithms/dqn.py b/openrl/algorithms/dqn.py
index bbca547b..ebd8d727 100644
--- a/openrl/algorithms/dqn.py
+++ b/openrl/algorithms/dqn.py
@@ -167,7 +167,9 @@ def prepare_loss(
         )
 
         q_targets = rewards_batch + self.gamma * max_next_q_values * next_masks_batch
-        q_loss = torch.mean(F.mse_loss(q_values, q_targets.detach()))  # mean squared error loss function
+        q_loss = torch.mean(
+            F.mse_loss(q_values, q_targets.detach())
+        )  # mean squared error loss function
 
         loss_list.append(q_loss)
 
diff --git a/openrl/algorithms/vdn.py b/openrl/algorithms/vdn.py
index f1215c03..83bdb5ed 100644
--- a/openrl/algorithms/vdn.py
+++ b/openrl/algorithms/vdn.py
@@ -211,7 +211,9 @@ def prepare_loss(
         rewards_batch = rewards_batch.reshape(-1, self.n_agent, 1)
         rewards_batch = torch.sum(rewards_batch, dim=1, keepdim=True).view(-1, 1)
         q_targets = rewards_batch + self.gamma * max_next_q_values * next_masks_batch
-        q_loss = torch.mean(F.mse_loss(q_values, q_targets.detach()))  # mean squared error loss function
+        q_loss = torch.mean(
+            F.mse_loss(q_values, q_targets.detach())
+        )  # mean squared error loss function
         loss_list.append(q_loss)
 
         return loss_list
diff --git a/openrl/arena/__init__.py b/openrl/arena/__init__.py
index 4bea924d..cb154a9f 100644
--- a/openrl/arena/__init__.py
+++ b/openrl/arena/__init__.py
@@ -30,9 +30,11 @@ def make_arena(
     **kwargs,
 ):
     if custom_build_env is None:
+        from openrl.envs import PettingZoo
+
         if (
             env_id in pettingzoo_all_envs
-            or env_id in openrl.envs.PettingZoo.registration.pettingzoo_env_dict.keys()
+            or env_id in PettingZoo.registration.pettingzoo_env_dict.keys()
         ):
             from openrl.envs.PettingZoo import make_PettingZoo_env
 
diff --git a/openrl/envs/mpe/rendering.py b/openrl/envs/mpe/rendering.py
index a7197dca..6dae5d66 100644
--- a/openrl/envs/mpe/rendering.py
+++ b/openrl/envs/mpe/rendering.py
@@ -31,12 +31,10 @@
 except ImportError:
     print(
         "Error occurred while running `from pyglet.gl import *`",
-        (
-            "HINT: make sure you have OpenGL installed. On Ubuntu, you can run 'apt-get"
-            " install python-opengl'. If you're running on a server, you may need a"
-            " virtual frame buffer; something like this should work: 'xvfb-run -s"
-            ' "-screen 0 1400x900x24" python \''
-        ),
+        "HINT: make sure you have OpenGL installed. On Ubuntu, you can run 'apt-get"
+        " install python-opengl'. If you're running on a server, you may need a"
+        " virtual frame buffer; something like this should work: 'xvfb-run -s"
+        ' "-screen 0 1400x900x24" python \'',
     )
 
 import math
diff --git a/openrl/envs/snake/snake.py b/openrl/envs/snake/snake.py
index 73e81229..4a5be6a5 100644
--- a/openrl/envs/snake/snake.py
+++ b/openrl/envs/snake/snake.py
@@ -674,7 +674,9 @@ class Snake:
     def __init__(self, player_id, board_width, board_height, init_len):
         self.actions = [-2, 2, -1, 1]
         self.actions_name = {-2: "up", 2: "down", -1: "left", 1: "right"}
-        self.direction = random.choice(self.actions)  # directions [-2, 2, -1, 1] mean [up, down, left, right]
+        self.direction = random.choice(
+            self.actions
+        )  # directions [-2, 2, -1, 1] mean [up, down, left, right]
         self.board_width = board_width
         self.board_height = board_height
         x = random.randrange(0, board_height)
diff --git a/openrl/envs/vec_env/async_venv.py b/openrl/envs/vec_env/async_venv.py
index 54ab2c80..dd654599 100644
--- a/openrl/envs/vec_env/async_venv.py
+++ b/openrl/envs/vec_env/async_venv.py
@@ -234,10 +234,8 @@ def reset_send(
 
         if self._state != AsyncState.DEFAULT:
             raise AlreadyPendingCallError(
-                (
-                    "Calling `reset_send` while waiting for a pending call to"
-                    f" `{self._state.value}` to complete"
-                ),
+                "Calling `reset_send` while waiting for a pending call to"
+                f" `{self._state.value}` to complete",
                 self._state.value,
             )
 
@@ -329,10 +327,8 @@ def step_send(self, actions: np.ndarray):
         self._assert_is_running()
         if self._state != AsyncState.DEFAULT:
             raise AlreadyPendingCallError(
-                (
-                    "Calling `step_send` while waiting for a pending call to"
-                    f" `{self._state.value}` to complete."
-                ),
+                "Calling `step_send` while waiting for a pending call to"
+                f" `{self._state.value}` to complete.",
                 self._state.value,
             )
 
@@ -342,9 +338,7 @@ def step_send(self, actions: np.ndarray):
             pipe.send(("step", action))
         self._state = AsyncState.WAITING_STEP
 
-    def step_fetch(
-        self, timeout: Optional[Union[int, float]] = None
-    ) -> Union[
+    def step_fetch(self, timeout: Optional[Union[int, float]] = None) -> Union[
         Tuple[Any, NDArray[Any], NDArray[Any], List[Dict[str, Any]]],
         Tuple[Any, NDArray[Any], NDArray[Any], NDArray[Any], List[Dict[str, Any]]],
     ]:
@@ -576,10 +570,8 @@ def call_send(self, name: str, *args, **kwargs):
         self._assert_is_running()
         if self._state != AsyncState.DEFAULT:
             raise AlreadyPendingCallError(
-                (
-                    "Calling `call_send` while waiting "
-                    f"for a pending call to `{self._state.value}` to complete."
-                ),
+                "Calling `call_send` while waiting "
+                f"for a pending call to `{self._state.value}` to complete.",
                 str(self._state.value),
             )
 
@@ -636,10 +628,8 @@ def exec_func_send(self, func: Callable, indices, *args, **kwargs):
         self._assert_is_running()
         if self._state != AsyncState.DEFAULT:
             raise AlreadyPendingCallError(
-                (
-                    "Calling `exec_func_send` while waiting "
-                    f"for a pending call to `{self._state.value}` to complete."
-                ),
+                "Calling `exec_func_send` while waiting "
+                f"for a pending call to `{self._state.value}` to complete.",
                 str(self._state.value),
             )
 
@@ -717,10 +707,8 @@ def set_attr(self, name: str, values: Union[List[Any], Tuple[Any], object]):
 
         if self._state != AsyncState.DEFAULT:
             raise AlreadyPendingCallError(
-                (
-                    "Calling `set_attr` while waiting "
-                    f"for a pending call to `{self._state.value}` to complete."
-                ),
+                "Calling `set_attr` while waiting "
+                f"for a pending call to `{self._state.value}` to complete.",
                 str(self._state.value),
             )
 
diff --git a/openrl/utils/callbacks/checkpoint_callback.py b/openrl/utils/callbacks/checkpoint_callback.py
index a4b3f5b6..56bf31b8 100644
--- a/openrl/utils/callbacks/checkpoint_callback.py
+++ b/openrl/utils/callbacks/checkpoint_callback.py
@@ -72,9 +72,7 @@ def _checkpoint_path(self, checkpoint_type: str = "", extension: str = "") -> st
         """
         return os.path.join(
             self.save_path,
-            (
-                f"{self.name_prefix}_{checkpoint_type}{self.num_time_steps}_steps{'.' if extension else ''}{extension}"
-            ),
+            f"{self.name_prefix}_{checkpoint_type}{self.num_time_steps}_steps{'.' if extension else ''}{extension}",
         )
 
     def _on_step(self) -> bool:
diff --git a/openrl/utils/evaluation.py b/openrl/utils/evaluation.py
index d603daa5..391ba10f 100644
--- a/openrl/utils/evaluation.py
+++ b/openrl/utils/evaluation.py
@@ -68,12 +68,10 @@ def evaluate_policy(
 
     if not is_monitor_wrapped and warn:
         warnings.warn(
-            (
-                "Evaluation environment is not wrapped with a ``Monitor`` wrapper. This"
-                " may result in reporting modified episode lengths and rewards, if"
-                " other wrappers happen to modify these. Consider wrapping environment"
-                " first with ``Monitor`` wrapper."
-            ),
+            "Evaluation environment is not wrapped with a ``Monitor`` wrapper. This"
+            " may result in reporting modified episode lengths and rewards, if"
+            " other wrappers happen to modify these. Consider wrapping environment"
+            " first with ``Monitor`` wrapper.",
             UserWarning,
         )
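
Note on the arena fix in openrl/arena/__init__.py: the old code reached
`openrl.envs.PettingZoo.registration` through a plain attribute chain without
ever importing that subpackage, which fails because importing a package does
not automatically import its submodules; the patch instead imports the
subpackage explicitly before using it. Below is a minimal standalone sketch of
the same pitfall, using the stdlib `xml` package purely for illustration (an
assumed example, not code from this repository):

    import xml

    try:
        xml.dom  # AttributeError: `import xml` never loaded the `dom` subpackage
    except AttributeError as exc:
        print(f"attribute chain fails: {exc}")

    # An explicit import both loads the submodule and binds a local name,
    # mirroring `from openrl.envs import PettingZoo` in the patch.
    from xml import dom

    print(dom.getDOMImplementation())  # now resolves normally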