Skip to content

Commit

Permalink
check format
Browse files Browse the repository at this point in the history
  • Loading branch information
WentseChen committed Mar 22, 2024
1 parent dae9cb5 commit d2564da
Show file tree
Hide file tree
Showing 9 changed files with 59 additions and 41 deletions.
12 changes: 7 additions & 5 deletions examples/crafter/render_crafter.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@

import numpy as np

from openrl.configs.config import create_config_parser
from openrl.envs.common import make
from openrl.envs.wrappers import GIFWrapper
from openrl.modules.common import PPONet as Net
from openrl.runners.common import PPOAgent as Agent

from openrl.configs.config import create_config_parser

def render():
# begin to test
Expand All @@ -32,7 +32,7 @@ def render():
render_mode="human",
env_num=1,
)

# config
cfg_parser = create_config_parser()
cfg = cfg_parser.parse_args()
Expand All @@ -56,16 +56,18 @@ def render():

if all(done):
break
img = obs["policy"][0,0]

img = obs["policy"][0, 0]
img = img.transpose((1, 2, 0))
trajectory.append(img)

# save the trajectory to gif
import imageio

imageio.mimsave("run_results/crafter.gif", trajectory, duration=0.01)

env.close()


if __name__ == "__main__":
render()
3 changes: 2 additions & 1 deletion examples/crafter/train_crafter.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@

import numpy as np

from openrl.configs.config import create_config_parser
from openrl.envs.common import make
from openrl.envs.wrappers import GIFWrapper
from openrl.modules.common import PPONet as Net
from openrl.runners.common import PPOAgent as Agent

from openrl.configs.config import create_config_parser

def train():
# create environment
Expand All @@ -43,5 +43,6 @@ def train():
env.close()
return agent


if __name__ == "__main__":
agent = train()
4 changes: 2 additions & 2 deletions openrl/envs/common/registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,11 +150,11 @@ def make(
)
elif id.startswith("Crafter"):
from openrl.envs.crafter import make_crafter_envs

env_fns = make_crafter_envs(
id=id, env_num=env_num, render_mode=convert_render_mode, **kwargs
)

elif (
id in openrl.envs.pettingzoo_all_envs
or id in openrl.envs.PettingZoo.registration.pettingzoo_env_dict.keys()
Expand Down
1 change: 0 additions & 1 deletion openrl/envs/crafter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ def make_crafter_envs(
render_mode: Optional[Union[str, List[str]]] = None,
**kwargs,
) -> List[Callable[[], Env]]:

from openrl.envs.wrappers import (
AutoReset,
DictWrapper,
Expand Down
22 changes: 10 additions & 12 deletions openrl/envs/crafter/crafter.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,11 @@

from typing import Any, Dict, List, Optional, Union

import numpy as np
import crafter
import gymnasium as gym
import numpy as np
from gymnasium import Wrapper

import crafter


class CrafterWrapper(Wrapper):
def __init__(
Expand All @@ -33,19 +32,19 @@ def __init__(
disable_env_checker: Optional[bool] = None,
**kwargs
):

self.env_name = name

self.env = crafter.Env()
self.env = crafter.Recorder(
self.env, "run_results/crafter_traj",
save_stats=False, # if True, save the environment stats under run_results/crafter_traj
self.env,
"run_results/crafter_traj",
save_stats=False, # if True, save the environment stats under run_results/crafter_traj
save_episode=False,
save_video=False,
)

super().__init__(self.env)

shape = self.env.observation_space.shape
shape = (shape[2],) + shape[0:2]
self.observation_space = gym.spaces.Box(
Expand All @@ -59,7 +58,7 @@ def step(self, action: int):
obs = self.convert_observation(obs)

return obs, reward, done, truncated, info

def reset(
self,
seed: Optional[int] = None,
Expand All @@ -70,10 +69,9 @@ def reset(
obs = self.convert_observation(obs)

return obs, info

def convert_observation(self, observation: np.array):
obs = np.asarray(observation, dtype=np.uint8)
obs = obs.transpose((2, 0, 1))

return obs

10 changes: 6 additions & 4 deletions openrl/envs/mpe/rendering.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,12 @@
except ImportError:
print(
"Error occured while running `from pyglet.gl import *`",
"HINT: make sure you have OpenGL install. On Ubuntu, you can run 'apt-get"
" install python-opengl'. If you're running on a server, you may need a"
" virtual frame buffer; something like this should work: 'xvfb-run -s"
' "-screen 0 1400x900x24" python <your_script.py>\'',
(
"HINT: make sure you have OpenGL install. On Ubuntu, you can run 'apt-get"
" install python-opengl'. If you're running on a server, you may need a"
" virtual frame buffer; something like this should work: 'xvfb-run -s"
' "-screen 0 1400x900x24" python <your_script.py>\''
),
)

import math
Expand Down
34 changes: 23 additions & 11 deletions openrl/envs/vec_env/async_venv.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,8 +234,10 @@ def reset_send(

if self._state != AsyncState.DEFAULT:
raise AlreadyPendingCallError(
"Calling `reset_send` while waiting for a pending call to"
f" `{self._state.value}` to complete",
(
"Calling `reset_send` while waiting for a pending call to"
f" `{self._state.value}` to complete"
),
self._state.value,
)

Expand Down Expand Up @@ -327,8 +329,10 @@ def step_send(self, actions: np.ndarray):
self._assert_is_running()
if self._state != AsyncState.DEFAULT:
raise AlreadyPendingCallError(
"Calling `step_send` while waiting for a pending call to"
f" `{self._state.value}` to complete.",
(
"Calling `step_send` while waiting for a pending call to"
f" `{self._state.value}` to complete."
),
self._state.value,
)

Expand All @@ -338,7 +342,9 @@ def step_send(self, actions: np.ndarray):
pipe.send(("step", action))
self._state = AsyncState.WAITING_STEP

def step_fetch(self, timeout: Optional[Union[int, float]] = None) -> Union[
def step_fetch(
self, timeout: Optional[Union[int, float]] = None
) -> Union[
Tuple[Any, NDArray[Any], NDArray[Any], List[Dict[str, Any]]],
Tuple[Any, NDArray[Any], NDArray[Any], NDArray[Any], List[Dict[str, Any]]],
]:
Expand Down Expand Up @@ -570,8 +576,10 @@ def call_send(self, name: str, *args, **kwargs):
self._assert_is_running()
if self._state != AsyncState.DEFAULT:
raise AlreadyPendingCallError(
"Calling `call_send` while waiting "
f"for a pending call to `{self._state.value}` to complete.",
(
"Calling `call_send` while waiting "
f"for a pending call to `{self._state.value}` to complete."
),
str(self._state.value),
)

Expand Down Expand Up @@ -628,8 +636,10 @@ def exec_func_send(self, func: Callable, indices, *args, **kwargs):
self._assert_is_running()
if self._state != AsyncState.DEFAULT:
raise AlreadyPendingCallError(
"Calling `exec_func_send` while waiting "
f"for a pending call to `{self._state.value}` to complete.",
(
"Calling `exec_func_send` while waiting "
f"for a pending call to `{self._state.value}` to complete."
),
str(self._state.value),
)

Expand Down Expand Up @@ -707,8 +717,10 @@ def set_attr(self, name: str, values: Union[List[Any], Tuple[Any], object]):

if self._state != AsyncState.DEFAULT:
raise AlreadyPendingCallError(
"Calling `set_attr` while waiting "
f"for a pending call to `{self._state.value}` to complete.",
(
"Calling `set_attr` while waiting "
f"for a pending call to `{self._state.value}` to complete."
),
str(self._state.value),
)

Expand Down
4 changes: 3 additions & 1 deletion openrl/utils/callbacks/checkpoint_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,9 @@ def _checkpoint_path(self, checkpoint_type: str = "", extension: str = "") -> st
"""
return os.path.join(
self.save_path,
f"{self.name_prefix}_{checkpoint_type}{self.num_time_steps}_steps{'.' if extension else ''}{extension}",
(
f"{self.name_prefix}_{checkpoint_type}{self.num_time_steps}_steps{'.' if extension else ''}{extension}"
),
)

def _on_step(self) -> bool:
Expand Down
10 changes: 6 additions & 4 deletions openrl/utils/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,12 @@ def evaluate_policy(

if not is_monitor_wrapped and warn:
warnings.warn(
"Evaluation environment is not wrapped with a ``Monitor`` wrapper. This"
" may result in reporting modified episode lengths and rewards, if"
" other wrappers happen to modify these. Consider wrapping environment"
" first with ``Monitor`` wrapper.",
(
"Evaluation environment is not wrapped with a ``Monitor`` wrapper. This"
" may result in reporting modified episode lengths and rewards, if"
" other wrappers happen to modify these. Consider wrapping environment"
" first with ``Monitor`` wrapper."
),
UserWarning,
)

Expand Down

0 comments on commit d2564da

Please sign in to comment.