Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support gymnasium environment #119

Merged
merged 1 commit into from
Mar 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions nnabla_rl/environments/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Copyright 2020,2021 Sony Corporation.
# Copyright 2021,2022,2023 Sony Group Corporation.
# Copyright 2021,2022,2023,2024 Sony Group Corporation.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -14,14 +14,16 @@
# limitations under the License.

from gym.envs.registration import register
from gymnasium.envs.registration import register as gymnasium_register

from nnabla_rl.environments.dummy import (DummyAtariEnv, DummyContinuous, DummyContinuousActionGoalEnv, DummyDiscrete, # noqa
DummyDiscreteActionGoalEnv, DummyDiscreteImg, DummyContinuousImg,
DummyFactoredContinuous, DummyMujocoEnv,
DummyTupleContinuous, DummyTupleDiscrete, DummyTupleMixed,
DummyTupleStateContinuous, DummyTupleStateDiscrete,
DummyTupleActionContinuous, DummyTupleActionDiscrete,
DummyHybridEnv)
DummyHybridEnv,
DummyGymnasiumAtariEnv, DummyGymnasiumMujocoEnv)

register(
id='FakeMujocoNNablaRL-v1',
Expand Down Expand Up @@ -87,3 +89,16 @@
entry_point='nnabla_rl.environments.dummy:DummyHybridEnv',
max_episode_steps=10
)


# Register the gymnasium-flavoured dummy environments. Every entry shares the
# same 10-step episode cap, so the id/entry-point pairs are kept in one table.
for _env_id, _entry_point in (
    ('FakeGymnasiumMujocoNNablaRL-v1', 'nnabla_rl.environments.dummy:DummyGymnasiumMujocoEnv'),
    ('FakeGymnasiumAtariNNablaRLNoFrameskip-v1', 'nnabla_rl.environments.dummy:DummyGymnasiumAtariEnv'),
):
    gymnasium_register(
        id=_env_id,
        entry_point=_entry_point,
        max_episode_steps=10
    )
# Do not leak the loop variables into the package namespace.
del _env_id, _entry_point
89 changes: 88 additions & 1 deletion nnabla_rl/environments/dummy.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Copyright 2020,2021 Sony Corporation.
# Copyright 2021,2022,2023 Sony Group Corporation.
# Copyright 2021,2022,2023,2024 Sony Group Corporation.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -16,8 +16,10 @@
from typing import TYPE_CHECKING, cast

import gym
import gymnasium
import numpy as np
from gym.envs.registration import EnvSpec
from gymnasium.envs.registration import EnvSpec as GymnasiumEnvSpec

if TYPE_CHECKING:
from gym.utils.seeding import RandomNumberGenerator
Expand Down Expand Up @@ -309,3 +311,88 @@ def __init__(self, max_episode_steps=None):
super(DummyHybridEnv, self).__init__(max_episode_steps=max_episode_steps)
self.action_space = gym.spaces.Tuple((gym.spaces.Discrete(5), gym.spaces.Box(low=0.0, high=1.0, shape=(5, ))))
self.observation_space = gym.spaces.Box(low=0.0, high=1.0, shape=(5, ))


# =========== gymnasium ==========
class AbstractDummyGymnasiumEnv(gymnasium.Env):
    """Base class for dummy environments that follow the gymnasium API.

    Subclasses are expected to define ``observation_space`` (and
    ``action_space``); observations are sampled from ``observation_space``
    and rewards are drawn from a standard normal distribution.

    Args:
        max_episode_steps: Number of steps after which ``step`` reports
            ``truncated=True``. ``None`` disables truncation.
    """

    def __init__(self, max_episode_steps):
        self.spec = GymnasiumEnvSpec('dummy-v0', max_episode_steps=max_episode_steps)
        self._episode_steps = 0

    def reset(self, *, seed=None, options=None):
        """Start a new episode.

        Args:
            seed: Optional seed, forwarded to ``gymnasium.Env.reset``.
            options: Accepted for gymnasium API compatibility; unused.

        Returns:
            A ``(observation, info)`` pair where ``info`` is an empty dict.
        """
        super().reset(seed=seed)
        self._episode_steps = 0
        return self.observation_space.sample(), {}

    def step(self, a):
        """Advance the episode by one step; the action ``a`` is ignored.

        Returns:
            The gymnasium 5-tuple
            ``(observation, reward, terminated, truncated, info)``.
            ``terminated`` is always False. ``truncated`` becomes True once
            ``max_episode_steps`` steps have been taken since the last reset.
        """
        next_state = self.observation_space.sample()
        reward = np.random.randn()
        # Count the step being taken before checking the limit, so that the
        # ``max_episode_steps``-th step is the one that reports truncation.
        self._episode_steps += 1
        terminated = False
        max_episode_steps = self.spec.max_episode_steps
        if max_episode_steps is None:
            truncated = False
        else:
            truncated = bool(self._episode_steps >= max_episode_steps)
        info = {'rnn_states': {'dummy_scope': {'dummy_state1': 1, 'dummy_state2': 2}}}
        return next_state, reward, terminated, truncated, info


class DummyGymnasiumAtariEnv(AbstractDummyGymnasiumEnv):
    """Dummy Atari-like environment that follows the gymnasium API.

    Observations are random ``(84, 84, 3)`` uint8 images and there are four
    discrete actions. ``reset`` returns ``(observation, info)`` and ``step``
    returns the gymnasium 5-tuple, so the environment can be wrapped by code
    that unpacks those shapes.

    Args:
        done_at_random: If True, each step terminates the episode with
            probability 1/10.
        max_episode_length: Number of steps after which ``step`` reports
            ``truncated=True``. ``None`` disables truncation.
    """

    class DummyALE(object):
        """Minimal stand-in for the ALE interface exposing only ``lives``."""

        def __init__(self):
            self._lives = 100

        def lives(self):
            """Return the remaining lives; every call consumes one life."""
            self._lives -= 1
            # Wrap around instead of going negative.
            if self._lives < 0:
                self._lives = 100
            return self._lives

    # seeding.np_random outputs np_random and seed
    np_random = cast("RandomNumberGenerator", nnabla_rl.random.drng)

    def __init__(self, done_at_random=True, max_episode_length=None):
        super(DummyGymnasiumAtariEnv, self).__init__(
            max_episode_steps=max_episode_length)
        self.action_space = gymnasium.spaces.Discrete(4)
        self.observation_space = gymnasium.spaces.Box(
            low=0, high=255, shape=(84, 84, 3), dtype=np.uint8)
        self.ale = DummyGymnasiumAtariEnv.DummyALE()
        self._done_at_random = done_at_random
        self._max_episode_length = max_episode_length
        # Stays None until reset() is called; step() checks for that.
        self._episode_length = None

    def step(self, action):
        """Advance the episode by one step; ``action`` is ignored.

        Returns:
            ``(observation, reward, terminated, truncated, info)`` with a
            constant reward of 1.0. ``terminated`` reflects the random
            episode end, ``truncated`` reflects ``max_episode_length``.
        """
        assert self._episode_length is not None
        observation = self.observation_space.sample()
        self._episode_length += 1
        if self._done_at_random:
            terminated = bool(self.np_random.integers(10) == 0)
        else:
            terminated = False
        if self._max_episode_length is None:
            truncated = False
        else:
            truncated = bool(self._max_episode_length <= self._episode_length)
        return observation, 1.0, terminated, truncated, {'needs_reset': False}

    def reset(self, *, seed=None, options=None):
        """Start a new episode.

        ``seed`` and ``options`` are accepted only for gymnasium API
        compatibility and are not used by this dummy environment.

        Returns:
            A ``(observation, info)`` pair where ``info`` is an empty dict.
        """
        self._episode_length = 0
        return self.observation_space.sample(), {}

    def get_action_meanings(self):
        """Return the names of the four discrete actions."""
        return ['NOOP', 'FIRE', 'LEFT', 'RIGHT']


class DummyGymnasiumMujocoEnv(AbstractDummyGymnasiumEnv):
    """Dummy continuous-control environment that follows the gymnasium API.

    Both the action and the observation space are 5-dimensional boxes
    bounded to [0, 1].
    """

    def __init__(self, max_episode_steps=None):
        super().__init__(max_episode_steps=max_episode_steps)
        self.action_space = gymnasium.spaces.Box(low=0.0, high=1.0, shape=(5, ))
        self.observation_space = gymnasium.spaces.Box(low=0.0, high=1.0, shape=(5, ))

    def get_dataset(self):
        """Build a random dataset of 2000 transitions.

        Returns:
            A dict with the keys 'observations', 'actions', 'rewards',
            'terminals' and 'timeouts'. Each value is an array whose first
            dimension is the dataset size; 'timeouts' is all zeros.
        """
        num_samples = 2000
        # Sample in the same order as the keys are listed: observations, then actions.
        sampled_observations = [self.observation_space.sample() for _ in range(num_samples)]
        sampled_actions = [self.action_space.sample() for _ in range(num_samples)]
        return {
            'observations': np.stack(sampled_observations, axis=0),
            'actions': np.stack(sampled_actions, axis=0),
            'rewards': np.random.randn(num_samples, 1),
            'terminals': np.random.randint(2, size=(num_samples, 1)),
            'timeouts': np.zeros((num_samples, 1)),
        }
3 changes: 2 additions & 1 deletion nnabla_rl/environments/wrappers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Copyright 2020,2021 Sony Corporation.
# Copyright 2021,2022,2023 Sony Group Corporation.
# Copyright 2021,2022,2023,2024 Sony Group Corporation.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -20,3 +20,4 @@
from nnabla_rl.environments.wrappers.atari import make_atari, wrap_deepmind # noqa
from nnabla_rl.environments.wrappers.hybrid_env import (EmbedActionWrapper, FlattenActionWrapper, # noqa
RemoveStepWrapper, ScaleActionWrapper, ScaleStateWrapper)
from nnabla_rl.environments.wrappers.gymnasium import Gymnasium2GymWrapper # noqa
14 changes: 11 additions & 3 deletions nnabla_rl/environments/wrappers/atari.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Copyright 2020,2021 Sony Corporation.
# Copyright 2021,2022,2023 Sony Group Corporation.
# Copyright 2021,2022,2023,2024 Sony Group Corporation.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -16,10 +16,12 @@

import cv2
import gym
import gymnasium
import numpy as np
from gym import spaces

import nnabla_rl as rl
from nnabla_rl.environments.wrappers.gymnasium import Gymnasium2GymWrapper
from nnabla_rl.external.atari_wrappers import (ClipRewardEnv, EpisodicLifeEnv, FireResetEnv, MaxAndSkipEnv,
NoopResetEnv, ScaledFloatFrame)

Expand Down Expand Up @@ -97,8 +99,14 @@ def __array__(self, dtype=None):
return out


def make_atari(env_id, max_frames_per_episode=None):
env = gym.make(env_id)
def make_atari(env_id, max_frames_per_episode=None, use_gymnasium=False):
if use_gymnasium:
env = gymnasium.make(env_id)
env = Gymnasium2GymWrapper(env)
# the gymnasium env is not wrapped with a TimeLimit wrapper
env = gym.wrappers.TimeLimit(env, max_episode_steps=env.spec.kwargs["max_num_frames_per_episode"])
else:
env = gym.make(env_id)
if max_frames_per_episode is not None:
env = env.unwrapped
env = gym.wrappers.TimeLimit(env, max_episode_steps=max_frames_per_episode)
Expand Down
88 changes: 88 additions & 0 deletions nnabla_rl/environments/wrappers/gymnasium.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
# Copyright 2024 Sony Group Corporation.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

import gym
from gym import spaces as gym_spaces
from gymnasium import spaces as gymnasium_spaces
from gymnasium.utils import seeding


class Gymnasium2GymWrapper(gym.Wrapper):
    """Expose a gymnasium environment through the gym-style API.

    ``reset`` returns only the observation and ``step`` returns the 4-tuple
    ``(obs, reward, done, info)``. The wrapped environment's observation and
    action spaces are translated into their ``gym.spaces`` counterparts.

    Args:
        env: The gymnasium environment to wrap.

    Raises:
        ValueError: If ``env`` is already a ``gym.Env`` or ``gym.Wrapper``.
        NotImplementedError: If a space (or a space nested in a Tuple/Dict
            space) is neither Box, Discrete, Tuple nor Dict.
    """

    def __init__(self, env):
        if isinstance(env, (gym.Env, gym.Wrapper)):
            raise ValueError("'env' should not be an instance of 'gym.Env' and 'gym.Wrapper'")

        super().__init__(env)

        # Tuple/Dict spaces are handled inside _translate_space, so the same
        # call covers both simple and composite spaces.
        self.observation_space = self._translate_space(env.observation_space)
        self.action_space = self._translate_space(env.action_space)

    def reset(self):
        """Reset the wrapped env and return only the observation."""
        obs, _ = self.env.reset()
        return obs

    def step(self, action):
        """Step the wrapped env and collapse its result to the 4-tuple form.

        ``done`` is True when the episode was terminated or truncated. The
        truncation flag is kept in ``info`` under "TimeLimit.truncated".
        """
        obs, reward, terminated, truncated, info = self.env.step(action)
        done = (terminated or truncated)
        info.update({"TimeLimit.truncated": truncated})
        return obs, reward, done, info

    def seed(self, seed: Optional[int] = None):
        """Replace the wrapped env's ``np_random`` with a freshly seeded generator.

        Returns:
            A one-element list containing the seed that was actually used.
        """
        np_random, seed = seeding.np_random(seed)
        self.env.np_random = np_random  # type: ignore
        return [seed]

    @property
    def unwrapped(self):
        # Returns this wrapper rather than the inner gymnasium env.
        # NOTE(review): presumably so that callers using ``env.unwrapped``
        # keep the gym-style API - confirm against callers.
        return self

    def _translate_space(self, space):
        """Convert a gymnasium space into the equivalent gym space.

        Tuple and Dict spaces are translated recursively, so nested
        composite spaces are supported.
        """
        if isinstance(space, gymnasium_spaces.Tuple):
            return gym_spaces.Tuple([self._translate_space(subspace) for subspace in space])
        if isinstance(space, gymnasium_spaces.Dict):
            return gym_spaces.Dict(
                {key: self._translate_space(subspace) for key, subspace in space.items()}
            )
        if isinstance(space, gymnasium_spaces.Box):
            return gym_spaces.Box(
                low=space.low,
                high=space.high,
                shape=space.shape,
                dtype=space.dtype
            )
        if isinstance(space, gymnasium_spaces.Discrete):
            return gym_spaces.Discrete(n=int(space.n))
        raise NotImplementedError(f"Unsupported gymnasium space: {type(space).__name__}")
23 changes: 17 additions & 6 deletions nnabla_rl/utils/reproductions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Copyright 2020,2021 Sony Corporation.
# Copyright 2021,2022 Sony Group Corporation.
# Copyright 2021,2022,2023,2024 Sony Group Corporation.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -16,12 +16,14 @@
import random as py_random

import gym
import gymnasium
import numpy as np

import nnabla as nn
import nnabla_rl as rl
from nnabla_rl.environments.environment_info import EnvironmentInfo
from nnabla_rl.environments.wrappers import NumpyFloat32Env, ScreenRenderEnv, make_atari, wrap_deepmind
from nnabla_rl.environments.wrappers import (Gymnasium2GymWrapper, NumpyFloat32Env, ScreenRenderEnv, make_atari,
wrap_deepmind)
from nnabla_rl.logger import logger


Expand Down Expand Up @@ -53,11 +55,14 @@ def build_atari_env(id_or_env,
print_info=True,
max_frames_per_episode=None,
frame_stack=True,
flicker_probability=0.0):
flicker_probability=0.0,
use_gymnasium=False):
if isinstance(id_or_env, gym.Env):
env = id_or_env
elif isinstance(id_or_env, gymnasium.Env):
env = Gymnasium2GymWrapper(id_or_env)
else:
env = make_atari(id_or_env, max_frames_per_episode=max_frames_per_episode)
env = make_atari(id_or_env, max_frames_per_episode=max_frames_per_episode, use_gymnasium=use_gymnasium)
if print_info:
print_env_info(env)

Expand All @@ -75,7 +80,7 @@ def build_atari_env(id_or_env,
return env


def build_mujoco_env(id_or_env, test=False, seed=None, render=False, print_info=True):
def build_mujoco_env(id_or_env, test=False, seed=None, render=False, print_info=True, use_gymnasium=False):
try:
# Add pybullet env
import pybullet_envs # noqa
Expand All @@ -91,8 +96,14 @@ def build_mujoco_env(id_or_env, test=False, seed=None, render=False, print_info=

if isinstance(id_or_env, gym.Env):
env = id_or_env
elif isinstance(id_or_env, gymnasium.Env):
env = Gymnasium2GymWrapper(id_or_env)
else:
env = gym.make(id_or_env)
if use_gymnasium:
env = gymnasium.make(id_or_env)
env = Gymnasium2GymWrapper(env)
else:
env = gym.make(id_or_env)

if print_info:
print_env_info(env)
Expand Down
10 changes: 6 additions & 4 deletions reproductions/algorithms/atari/a2c/a2c_reproduction.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2021 Sony Group Corporation.
# Copyright 2021,2022,2023,2024 Sony Group Corporation.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -28,9 +28,9 @@

def run_training(args):
set_global_seed(args.seed)
train_env = build_atari_env(args.env, seed=args.seed)
train_env = build_atari_env(args.env, seed=args.seed, use_gymnasium=args.use_gymnasium)
eval_env = build_atari_env(
args.env, test=True, seed=args.seed + 100, render=args.render)
args.env, test=True, seed=args.seed + 100, render=args.render, use_gymnasium=args.use_gymnasium)

iteration_num_hook = H.IterationNumHook(timing=100)

Expand All @@ -56,7 +56,8 @@ def run_training(args):
def run_showcase(args):
if args.snapshot_dir is None:
raise ValueError('Please specify the snapshot dir for showcasing')
eval_env = build_atari_env(args.env, test=True, seed=args.seed + 200, render=False)
eval_env = build_atari_env(args.env, test=True, seed=args.seed + 200, render=False,
use_gymnasium=args.use_gymnasium)
config = A.A2CConfig(gpu_id=args.gpu)
a2c = serializers.load_snapshot(args.snapshot_dir, eval_env, algorithm_kwargs={"config": config})
if not isinstance(a2c, A.A2C):
Expand All @@ -83,6 +84,7 @@ def main():
parser.add_argument('--save_timing', type=int, default=250000)
parser.add_argument('--eval_timing', type=int, default=250000)
parser.add_argument('--showcase_runs', type=int, default=10)
parser.add_argument('--use-gymnasium', action='store_true')

args = parser.parse_args()

Expand Down
Loading
Loading