fix arena petting zoo import error #258

Merged 1 commit on Oct 23, 2023
4 changes: 3 additions & 1 deletion openrl/algorithms/dqn.py
@@ -167,7 +167,9 @@ def prepare_loss(
)

q_targets = rewards_batch + self.gamma * max_next_q_values * next_masks_batch
q_loss = torch.mean(F.mse_loss(q_values, q_targets.detach())) # mean squared error loss
q_loss = torch.mean(
F.mse_loss(q_values, q_targets.detach())
) # mean squared error loss

loss_list.append(q_loss)

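The hunk above only re-wraps the loss line, but for context, here is a minimal sketch of the DQN TD target and MSE loss it computes (dummy tensors; the batch size, gamma, and shapes below are illustrative assumptions, not values from the PR):

```python
import torch
import torch.nn.functional as F

gamma = 0.99
q_values = torch.rand(32, 1)           # Q(s, a) for the sampled transitions
max_next_q_values = torch.rand(32, 1)  # max_a' Q(s', a') from the target network
rewards_batch = torch.rand(32, 1)
next_masks_batch = torch.ones(32, 1)   # 0 where the next state is terminal

# TD target; detach() keeps gradients out of the target branch
q_targets = rewards_batch + gamma * max_next_q_values * next_masks_batch
q_loss = torch.mean(F.mse_loss(q_values, q_targets.detach()))  # mean squared error loss
print(q_loss.item())
```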
4 changes: 3 additions & 1 deletion openrl/algorithms/vdn.py
@@ -211,7 +211,9 @@ def prepare_loss(
rewards_batch = rewards_batch.reshape(-1, self.n_agent, 1)
rewards_batch = torch.sum(rewards_batch, dim=1, keepdim=True).view(-1, 1)
q_targets = rewards_batch + self.gamma * max_next_q_values * next_masks_batch
q_loss = torch.mean(F.mse_loss(q_values, q_targets.detach())) # mean squared error loss
q_loss = torch.mean(
F.mse_loss(q_values, q_targets.detach())
) # mean squared error loss

loss_list.append(q_loss)
return loss_list
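In the VDN hunk above, the surrounding lines sum the per-agent rewards into a single team reward before forming the same TD target as in DQN. A small sketch of that aggregation (batch size and n_agent are illustrative assumptions):

```python
import torch

n_agent = 2
rewards_batch = torch.rand(32 * n_agent, 1)            # per-agent rewards, flattened
rewards_batch = rewards_batch.reshape(-1, n_agent, 1)  # (batch, n_agent, 1)
# Sum over the agent dimension so each transition has one team reward
rewards_batch = torch.sum(rewards_batch, dim=1, keepdim=True).view(-1, 1)
print(rewards_batch.shape)  # torch.Size([32, 1])
```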
4 changes: 3 additions & 1 deletion openrl/arena/__init__.py
@@ -30,9 +30,11 @@ def make_arena(
**kwargs,
):
if custom_build_env is None:
from openrl.envs import PettingZoo

if (
env_id in pettingzoo_all_envs
or env_id in openrl.envs.PettingZoo.registration.pettingzoo_env_dict.keys()
or env_id in PettingZoo.registration.pettingzoo_env_dict.keys()
):
from openrl.envs.PettingZoo import make_PettingZoo_env

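This is the change the PR title refers to: `openrl.envs.PettingZoo` is only available as an attribute of `openrl.envs` after the submodule has actually been imported, so resolving the registry through the dotted package path can fail depending on import order. Importing the submodule locally and using the bound name avoids that. A minimal sketch of the fixed lookup follows; the `resolve_env_maker` wrapper and its arguments are hypothetical, while the names inside it come from the diff:

```python
def resolve_env_maker(env_id, pettingzoo_all_envs):
    # Import the submodule explicitly so it is guaranteed to be loaded
    # before its registration table is read.
    from openrl.envs import PettingZoo

    if (
        env_id in pettingzoo_all_envs
        or env_id in PettingZoo.registration.pettingzoo_env_dict
    ):
        from openrl.envs.PettingZoo import make_PettingZoo_env

        return make_PettingZoo_env
    return None
```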
10 changes: 4 additions & 6 deletions openrl/envs/mpe/rendering.py
@@ -31,12 +31,10 @@
except ImportError:
print(
"Error occured while running `from pyglet.gl import *`",
(
"HINT: make sure you have OpenGL install. On Ubuntu, you can run 'apt-get"
" install python-opengl'. If you're running on a server, you may need a"
" virtual frame buffer; something like this should work: 'xvfb-run -s"
' "-screen 0 1400x900x24" python <your_script.py>\''
),
"HINT: make sure you have OpenGL install. On Ubuntu, you can run 'apt-get"
" install python-opengl'. If you're running on a server, you may need a"
" virtual frame buffer; something like this should work: 'xvfb-run -s"
' "-screen 0 1400x900x24" python <your_script.py>\'',
)

import math
4 changes: 3 additions & 1 deletion openrl/envs/snake/snake.py
@@ -674,7 +674,9 @@ class Snake:
def __init__(self, player_id, board_width, board_height, init_len):
self.actions = [-2, 2, -1, 1]
self.actions_name = {-2: "up", 2: "down", -1: "left", 1: "right"}
self.direction = random.choice(self.actions) # directions [-2, 2, -1, 1] mean [up, down, left, right]
self.direction = random.choice(
self.actions
) # directions [-2, 2, -1, 1] mean [up, down, left, right]
self.board_width = board_width
self.board_height = board_height
x = random.randrange(0, board_height)
34 changes: 11 additions & 23 deletions openrl/envs/vec_env/async_venv.py
@@ -234,10 +234,8 @@ def reset_send(

if self._state != AsyncState.DEFAULT:
raise AlreadyPendingCallError(
(
"Calling `reset_send` while waiting for a pending call to"
f" `{self._state.value}` to complete"
),
"Calling `reset_send` while waiting for a pending call to"
f" `{self._state.value}` to complete",
self._state.value,
)

@@ -329,10 +327,8 @@ def step_send(self, actions: np.ndarray):
self._assert_is_running()
if self._state != AsyncState.DEFAULT:
raise AlreadyPendingCallError(
(
"Calling `step_send` while waiting for a pending call to"
f" `{self._state.value}` to complete."
),
"Calling `step_send` while waiting for a pending call to"
f" `{self._state.value}` to complete.",
self._state.value,
)

@@ -342,9 +338,7 @@ def step_send(self, actions: np.ndarray):
pipe.send(("step", action))
self._state = AsyncState.WAITING_STEP

def step_fetch(
self, timeout: Optional[Union[int, float]] = None
) -> Union[
def step_fetch(self, timeout: Optional[Union[int, float]] = None) -> Union[
Tuple[Any, NDArray[Any], NDArray[Any], List[Dict[str, Any]]],
Tuple[Any, NDArray[Any], NDArray[Any], NDArray[Any], List[Dict[str, Any]]],
]:
@@ -576,10 +570,8 @@ def call_send(self, name: str, *args, **kwargs):
self._assert_is_running()
if self._state != AsyncState.DEFAULT:
raise AlreadyPendingCallError(
(
"Calling `call_send` while waiting "
f"for a pending call to `{self._state.value}` to complete."
),
"Calling `call_send` while waiting "
f"for a pending call to `{self._state.value}` to complete.",
str(self._state.value),
)

@@ -636,10 +628,8 @@ def exec_func_send(self, func: Callable, indices, *args, **kwargs):
self._assert_is_running()
if self._state != AsyncState.DEFAULT:
raise AlreadyPendingCallError(
(
"Calling `exec_func_send` while waiting "
f"for a pending call to `{self._state.value}` to complete."
),
"Calling `exec_func_send` while waiting "
f"for a pending call to `{self._state.value}` to complete.",
str(self._state.value),
)

@@ -717,10 +707,8 @@ def set_attr(self, name: str, values: Union[List[Any], Tuple[Any], object]):

if self._state != AsyncState.DEFAULT:
raise AlreadyPendingCallError(
(
"Calling `set_attr` while waiting "
f"for a pending call to `{self._state.value}` to complete."
),
"Calling `set_attr` while waiting "
f"for a pending call to `{self._state.value}` to complete.",
str(self._state.value),
)

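All of the hunks in this file make the same cosmetic change: the `AlreadyPendingCallError` message is passed as adjacent string literals instead of being wrapped in an extra pair of parentheses. The guard itself follows the usual async vector-env pattern; a simplified, self-contained sketch of that pattern is below (the class names and method bodies are illustrative, not the real implementation):

```python
from enum import Enum


class AsyncState(Enum):
    DEFAULT = "default"
    WAITING_STEP = "step"


class AlreadyPendingCallError(RuntimeError):
    def __init__(self, message, name):
        super().__init__(message)
        self.name = name


class MiniAsyncEnv:
    def __init__(self):
        self._state = AsyncState.DEFAULT

    def step_send(self, actions):
        # Refuse to start a new asynchronous call while another is still pending
        if self._state != AsyncState.DEFAULT:
            raise AlreadyPendingCallError(
                "Calling `step_send` while waiting for a pending call to"
                f" `{self._state.value}` to complete.",
                self._state.value,
            )
        self._state = AsyncState.WAITING_STEP  # worker pipes would receive actions here

    def step_fetch(self):
        # Results would be collected from the workers here
        self._state = AsyncState.DEFAULT
```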
4 changes: 1 addition & 3 deletions openrl/utils/callbacks/checkpoint_callback.py
@@ -72,9 +72,7 @@ def _checkpoint_path(self, checkpoint_type: str = "", extension: str = "") -> st
"""
return os.path.join(
self.save_path,
(
f"{self.name_prefix}_{checkpoint_type}{self.num_time_steps}_steps{'.' if extension else ''}{extension}"
),
f"{self.name_prefix}_{checkpoint_type}{self.num_time_steps}_steps{'.' if extension else ''}{extension}",
)

def _on_step(self) -> bool:
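The change above just inlines the path f-string without the extra parentheses. As a worked example of the filename it produces (the values below are hypothetical, chosen only to illustrate the format):

```python
import os

name_prefix = "rl_model"   # illustrative values, not from the PR
checkpoint_type = ""
num_time_steps = 10000
extension = "zip"
save_path = "./checkpoints"

path = os.path.join(
    save_path,
    f"{name_prefix}_{checkpoint_type}{num_time_steps}_steps"
    f"{'.' if extension else ''}{extension}",
)
print(path)  # e.g. ./checkpoints/rl_model_10000_steps.zip
```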
10 changes: 4 additions & 6 deletions openrl/utils/evaluation.py
@@ -68,12 +68,10 @@ def evaluate_policy(

if not is_monitor_wrapped and warn:
warnings.warn(
(
"Evaluation environment is not wrapped with a ``Monitor`` wrapper. This"
" may result in reporting modified episode lengths and rewards, if"
" other wrappers happen to modify these. Consider wrapping environment"
" first with ``Monitor`` wrapper."
),
"Evaluation environment is not wrapped with a ``Monitor`` wrapper. This"
" may result in reporting modified episode lengths and rewards, if"
" other wrappers happen to modify these. Consider wrapping environment"
" first with ``Monitor`` wrapper.",
UserWarning,
)
