Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove AtariEnv #334

Merged
merged 12 commits into from
May 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions hive/configs/atari/dqn.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,20 @@ kwargs:
max_steps_per_episode: 27000
stack_size: &stack_size 4
environment:
name: 'AtariEnv'
name: 'GymEnv'
kwargs:
env_name: 'Asterix'
env_name: 'ALE/Asterix-v5'
repeat_action_probability: 0.25 # probality = 1 / sticky_action_number
frameskip: 1 # this is required to be set to expose frameskip in wrapper
env_wrappers:
- name: AtariPreprocessing
kwargs:
frame_skip: 4
grayscale_newaxis: True
screen_size: 84
# This wrapper converts image from the TF format (channels last)
# to the PyTorch format (channels first)
- name: PermuteImageWrapper

agent:
name: 'DQNAgent'
Expand Down
2 changes: 1 addition & 1 deletion hive/envs/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from hive.envs.base import BaseEnv, ParallelEnv
from hive.envs.env_spec import EnvSpec
from hive.envs.gym_env import GymEnv
from hive.envs.gym.gym_env import GymEnv

try:
from hive.envs.minigrid import MiniGridEnv
Expand Down
1 change: 0 additions & 1 deletion hive/envs/atari/__init__.py

This file was deleted.

136 changes: 0 additions & 136 deletions hive/envs/atari/atari.py

This file was deleted.

2 changes: 0 additions & 2 deletions hive/envs/atari/requirements.txt

This file was deleted.

3 changes: 2 additions & 1 deletion hive/envs/env_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from hive.utils.registry import Registrable
import gymnasium as gym


class EnvWrapper(Registrable):
class GymWrapper(Registrable, gym.core.Wrapper):
"""A wrapper for callables that produce environment wrappers.

These wrapped callables can be partially initialized through configuration
Expand Down
1 change: 1 addition & 0 deletions hive/envs/gym/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from hive.envs.gym.gym_wrappers import FlattenWrapper, PermuteImageWrapper
8 changes: 4 additions & 4 deletions hive/envs/gym_env.py → hive/envs/gym/gym_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from hive.envs.base import BaseEnv
from hive.envs.env_spec import EnvSpec
from hive.envs.env_wrapper import EnvWrapper, apply_wrappers
from hive.envs.env_wrapper import GymWrapper, apply_wrappers
from hive.utils.registry import registry


Expand All @@ -18,7 +18,7 @@ class GymEnv(BaseEnv):
def __init__(
self,
env_name: str,
env_wrappers: List[EnvWrapper] = None,
env_wrappers: List[GymWrapper] = None,
num_players: int = 1,
render_mode: str = None,
**kwargs
Expand All @@ -27,7 +27,7 @@ def __init__(
Args:
env_name (str): Name of the environment (NOTE: make sure it is available
in gym.envs.registry.all())
env_wrappers (List[EnvWrapper]): List of environment wrappers to apply.
env_wrappers (List[GymWrapper]): List of environment wrappers to apply.
num_players (int): Number of players for the environment.
render_mode (str): One of None, "human", "rgb_array", "ansi", or
"rgb_array_list". See gym documentation for details.
Expand Down Expand Up @@ -101,6 +101,6 @@ def close(self):
]

registry.register_all(
EnvWrapper,
GymWrapper,
{wrapper.__name__: wrapper for wrapper in wrappers},
)
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
import gymnasium as gym
import numpy as np

from hive.utils.registry import registry
from hive.envs.env_wrapper import GymWrapper


class FlattenWrapper(gym.core.ObservationWrapper):
"""
Expand Down Expand Up @@ -76,3 +79,9 @@ def observation(self, obs):
return tuple(np.transpose(o, [2, 1, 0]) for o in obs)
else:
return np.transpose(obs, [2, 1, 0])


registry.register_all(
GymWrapper,
{"PermuteImageWrapper": PermuteImageWrapper, "FlattenWrapper": FlattenWrapper},
)
2 changes: 1 addition & 1 deletion hive/envs/marlgrid/marlgrid.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import gym
import numpy as np
from hive.envs import GymEnv, ParallelEnv
from hive.envs.wrappers.gym_wrappers import FlattenWrapper, PermuteImageWrapper
from hive.envs.gym.gym_wrappers import FlattenWrapper, PermuteImageWrapper
from marlgrid import envs
from gym.wrappers.compatibility import EnvCompatibility

Expand Down
4 changes: 2 additions & 2 deletions hive/envs/minigrid/minigrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
RGBImgPartialObsWrapper,
)

from hive.envs.gym_env import GymEnv
from hive.envs.wrappers.gym_wrappers import FlattenWrapper, PermuteImageWrapper
from hive.envs.gym.gym_env import GymEnv
from hive.envs.gym.gym_wrappers import FlattenWrapper, PermuteImageWrapper


class MiniGridEnv(GymEnv):
Expand Down
1 change: 0 additions & 1 deletion hive/envs/wrappers/__init__.py

This file was deleted.

38 changes: 33 additions & 5 deletions tests/hive/envs/test_atari_env.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,29 @@
import numpy as np
import pytest
from functools import partial
from hive.envs import GymEnv
from gymnasium.wrappers.atari_preprocessing import AtariPreprocessing
from hive.envs.gym.gym_wrappers import PermuteImageWrapper

from hive.envs.atari import AtariEnv

test_env_configs = [("Pong", 4, 84), ("Breakout", 1, 100)]
test_env_configs = [("ALE/Pong-v5", 4, 84), ("ALE/Asterix-v5", 1, 100)]


@pytest.mark.parametrize("env_name,frame_skip,screen_size", test_env_configs)
def test_reset_func(env_name, frame_skip, screen_size):
hive_env = AtariEnv(env_name, frame_skip, screen_size)
hive_env = GymEnv(
env_name,
repeat_action_probability=0.25,
frameskip=1,
env_wrappers=[
partial(
AtariPreprocessing,
frame_skip=frame_skip,
screen_size=screen_size,
grayscale_newaxis=True,
),
PermuteImageWrapper,
],
)
hive_observation, hive_turn = hive_env.reset()

assert isinstance(hive_observation, np.ndarray)
Expand All @@ -22,7 +37,20 @@ def test_reset_func(env_name, frame_skip, screen_size):

@pytest.mark.parametrize("env_name,frame_skip,screen_size", test_env_configs)
def test_step_func(env_name, frame_skip, screen_size):
hive_env = AtariEnv(env_name, frame_skip, screen_size)
hive_env = GymEnv(
env_name,
repeat_action_probability=0.25,
frameskip=1,
env_wrappers=[
partial(
AtariPreprocessing,
frame_skip=frame_skip,
screen_size=screen_size,
grayscale_newaxis=True,
),
PermuteImageWrapper,
],
)
for action in range(hive_env.env_spec.action_space[0].n):
hive_env.reset()
(
Expand Down
2 changes: 1 addition & 1 deletion tests/hive/envs/test_gym_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import numpy as np
import pytest

from hive.envs.gym_env import GymEnv
from hive.envs.gym.gym_env import GymEnv

test_environments = ["CartPole-v0", "MountainCar-v0"]

Expand Down