From d2564daf34ad9b334a1c0eedaf2d317557aca0e3 Mon Sep 17 00:00:00 2001 From: WentseChen Date: Thu, 21 Mar 2024 22:31:50 -0400 Subject: [PATCH] check format --- examples/crafter/render_crafter.py | 12 ++++--- examples/crafter/train_crafter.py | 3 +- openrl/envs/common/registration.py | 4 +-- openrl/envs/crafter/__init__.py | 1 - openrl/envs/crafter/crafter.py | 22 ++++++------ openrl/envs/mpe/rendering.py | 10 +++--- openrl/envs/vec_env/async_venv.py | 34 +++++++++++++------ openrl/utils/callbacks/checkpoint_callback.py | 4 ++- openrl/utils/evaluation.py | 10 +++--- 9 files changed, 59 insertions(+), 41 deletions(-) diff --git a/examples/crafter/render_crafter.py b/examples/crafter/render_crafter.py index d6e4844a..a093b53b 100644 --- a/examples/crafter/render_crafter.py +++ b/examples/crafter/render_crafter.py @@ -18,12 +18,12 @@ import numpy as np +from openrl.configs.config import create_config_parser from openrl.envs.common import make from openrl.envs.wrappers import GIFWrapper from openrl.modules.common import PPONet as Net from openrl.runners.common import PPOAgent as Agent -from openrl.configs.config import create_config_parser def render(): # begin to test @@ -32,7 +32,7 @@ def render(): render_mode="human", env_num=1, ) - + # config cfg_parser = create_config_parser() cfg = cfg_parser.parse_args() @@ -56,16 +56,18 @@ def render(): if all(done): break - - img = obs["policy"][0,0] + + img = obs["policy"][0, 0] img = img.transpose((1, 2, 0)) trajectory.append(img) - + # save the trajectory to gif import imageio + imageio.mimsave("run_results/crafter.gif", trajectory, duration=0.01) env.close() + if __name__ == "__main__": render() diff --git a/examples/crafter/train_crafter.py b/examples/crafter/train_crafter.py index a1c2fb53..86a4006c 100644 --- a/examples/crafter/train_crafter.py +++ b/examples/crafter/train_crafter.py @@ -18,12 +18,12 @@ import numpy as np +from openrl.configs.config import create_config_parser from openrl.envs.common import make from openrl.envs.wrappers import GIFWrapper from openrl.modules.common import PPONet as Net from openrl.runners.common import PPOAgent as Agent -from openrl.configs.config import create_config_parser def train(): # create environment @@ -43,5 +43,6 @@ def train(): env.close() return agent + if __name__ == "__main__": agent = train() diff --git a/openrl/envs/common/registration.py b/openrl/envs/common/registration.py index 104a69e1..dc177c68 100644 --- a/openrl/envs/common/registration.py +++ b/openrl/envs/common/registration.py @@ -150,11 +150,11 @@ def make( ) elif id.startswith("Crafter"): from openrl.envs.crafter import make_crafter_envs - + env_fns = make_crafter_envs( id=id, env_num=env_num, render_mode=convert_render_mode, **kwargs ) - + elif ( id in openrl.envs.pettingzoo_all_envs or id in openrl.envs.PettingZoo.registration.pettingzoo_env_dict.keys() diff --git a/openrl/envs/crafter/__init__.py b/openrl/envs/crafter/__init__.py index b02bf1e4..3a291eae 100644 --- a/openrl/envs/crafter/__init__.py +++ b/openrl/envs/crafter/__init__.py @@ -30,7 +30,6 @@ def make_crafter_envs( render_mode: Optional[Union[str, List[str]]] = None, **kwargs, ) -> List[Callable[[], Env]]: - from openrl.envs.wrappers import ( AutoReset, DictWrapper, diff --git a/openrl/envs/crafter/crafter.py b/openrl/envs/crafter/crafter.py index cca847d4..c03d8426 100644 --- a/openrl/envs/crafter/crafter.py +++ b/openrl/envs/crafter/crafter.py @@ -18,12 +18,11 @@ from typing import Any, Dict, List, Optional, Union -import numpy as np +import crafter import gymnasium as gym +import numpy as np from gymnasium import Wrapper -import crafter - class CrafterWrapper(Wrapper): def __init__( @@ -33,19 +32,19 @@ def __init__( disable_env_checker: Optional[bool] = None, **kwargs ): - self.env_name = name - + self.env = crafter.Env() self.env = crafter.Recorder( - self.env, "run_results/crafter_traj", - save_stats=False, # if True, save the stats of the environment to example/crafter/crafter_traj + self.env, + "run_results/crafter_traj", + save_stats=False, # if True, save the stats of the environment to example/crafter/crafter_traj save_episode=False, save_video=False, ) - + super().__init__(self.env) - + shape = self.env.observation_space.shape shape = (shape[2],) + shape[0:2] self.observation_space = gym.spaces.Box( @@ -59,7 +58,7 @@ def step(self, action: int): obs = self.convert_observation(obs) return obs, reward, done, truncated, info - + def reset( self, seed: Optional[int] = None, @@ -70,10 +69,9 @@ def reset( obs = self.convert_observation(obs) return obs, info - + def convert_observation(self, observation: np.array): obs = np.asarray(observation, dtype=np.uint8) obs = obs.transpose((2, 0, 1)) return obs - diff --git a/openrl/envs/mpe/rendering.py b/openrl/envs/mpe/rendering.py index 6dae5d66..a7197dca 100644 --- a/openrl/envs/mpe/rendering.py +++ b/openrl/envs/mpe/rendering.py @@ -31,10 +31,12 @@ except ImportError: print( "Error occured while running `from pyglet.gl import *`", - "HINT: make sure you have OpenGL install. On Ubuntu, you can run 'apt-get" - " install python-opengl'. If you're running on a server, you may need a" - " virtual frame buffer; something like this should work: 'xvfb-run -s" - ' "-screen 0 1400x900x24" python \'', + ( + "HINT: make sure you have OpenGL install. On Ubuntu, you can run 'apt-get" + " install python-opengl'. If you're running on a server, you may need a" + " virtual frame buffer; something like this should work: 'xvfb-run -s" + ' "-screen 0 1400x900x24" python \'' + ), ) import math diff --git a/openrl/envs/vec_env/async_venv.py b/openrl/envs/vec_env/async_venv.py index 141532ba..e4f10d2b 100644 --- a/openrl/envs/vec_env/async_venv.py +++ b/openrl/envs/vec_env/async_venv.py @@ -234,8 +234,10 @@ def reset_send( if self._state != AsyncState.DEFAULT: raise AlreadyPendingCallError( - "Calling `reset_send` while waiting for a pending call to" - f" `{self._state.value}` to complete", + ( + "Calling `reset_send` while waiting for a pending call to" + f" `{self._state.value}` to complete" + ), self._state.value, ) @@ -327,8 +329,10 @@ def step_send(self, actions: np.ndarray): self._assert_is_running() if self._state != AsyncState.DEFAULT: raise AlreadyPendingCallError( - "Calling `step_send` while waiting for a pending call to" - f" `{self._state.value}` to complete.", + ( + "Calling `step_send` while waiting for a pending call to" + f" `{self._state.value}` to complete." + ), self._state.value, ) @@ -338,7 +342,9 @@ def step_send(self, actions: np.ndarray): pipe.send(("step", action)) self._state = AsyncState.WAITING_STEP - def step_fetch(self, timeout: Optional[Union[int, float]] = None) -> Union[ + def step_fetch( + self, timeout: Optional[Union[int, float]] = None + ) -> Union[ Tuple[Any, NDArray[Any], NDArray[Any], List[Dict[str, Any]]], Tuple[Any, NDArray[Any], NDArray[Any], NDArray[Any], List[Dict[str, Any]]], ]: @@ -570,8 +576,10 @@ def call_send(self, name: str, *args, **kwargs): self._assert_is_running() if self._state != AsyncState.DEFAULT: raise AlreadyPendingCallError( - "Calling `call_send` while waiting " - f"for a pending call to `{self._state.value}` to complete.", + ( + "Calling `call_send` while waiting " + f"for a pending call to `{self._state.value}` to complete." + ), str(self._state.value), ) @@ -628,8 +636,10 @@ def exec_func_send(self, func: Callable, indices, *args, **kwargs): self._assert_is_running() if self._state != AsyncState.DEFAULT: raise AlreadyPendingCallError( - "Calling `exec_func_send` while waiting " - f"for a pending call to `{self._state.value}` to complete.", + ( + "Calling `exec_func_send` while waiting " + f"for a pending call to `{self._state.value}` to complete." + ), str(self._state.value), ) @@ -707,8 +717,10 @@ def set_attr(self, name: str, values: Union[List[Any], Tuple[Any], object]): if self._state != AsyncState.DEFAULT: raise AlreadyPendingCallError( - "Calling `set_attr` while waiting " - f"for a pending call to `{self._state.value}` to complete.", + ( + "Calling `set_attr` while waiting " + f"for a pending call to `{self._state.value}` to complete." + ), str(self._state.value), ) diff --git a/openrl/utils/callbacks/checkpoint_callback.py b/openrl/utils/callbacks/checkpoint_callback.py index 56bf31b8..a4b3f5b6 100644 --- a/openrl/utils/callbacks/checkpoint_callback.py +++ b/openrl/utils/callbacks/checkpoint_callback.py @@ -72,7 +72,9 @@ def _checkpoint_path(self, checkpoint_type: str = "", extension: str = "") -> st """ return os.path.join( self.save_path, - f"{self.name_prefix}_{checkpoint_type}{self.num_time_steps}_steps{'.' if extension else ''}{extension}", + ( + f"{self.name_prefix}_{checkpoint_type}{self.num_time_steps}_steps{'.' if extension else ''}{extension}" + ), ) def _on_step(self) -> bool: diff --git a/openrl/utils/evaluation.py b/openrl/utils/evaluation.py index c008c437..7d307f1c 100644 --- a/openrl/utils/evaluation.py +++ b/openrl/utils/evaluation.py @@ -68,10 +68,12 @@ def evaluate_policy( if not is_monitor_wrapped and warn: warnings.warn( - "Evaluation environment is not wrapped with a ``Monitor`` wrapper. This" - " may result in reporting modified episode lengths and rewards, if" - " other wrappers happen to modify these. Consider wrapping environment" - " first with ``Monitor`` wrapper.", + ( + "Evaluation environment is not wrapped with a ``Monitor`` wrapper. This" + " may result in reporting modified episode lengths and rewards, if" + " other wrappers happen to modify these. Consider wrapping environment" + " first with ``Monitor`` wrapper." + ), UserWarning, )