Skip to content

Commit

Permalink
Merge pull request #119 from sony/feature/20240129-support-gymnasium
Browse files Browse the repository at this point in the history
Support gymnasium environment
  • Loading branch information
ishihara-y authored Mar 27, 2024
2 parents 2b4b7e8 + d763e30 commit b98c747
Show file tree
Hide file tree
Showing 33 changed files with 462 additions and 113 deletions.
19 changes: 17 additions & 2 deletions nnabla_rl/environments/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Copyright 2020,2021 Sony Corporation.
# Copyright 2021,2022,2023 Sony Group Corporation.
# Copyright 2021,2022,2023,2024 Sony Group Corporation.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -14,14 +14,16 @@
# limitations under the License.

from gym.envs.registration import register
from gymnasium.envs.registration import register as gymnasium_register

from nnabla_rl.environments.dummy import (DummyAtariEnv, DummyContinuous, DummyContinuousActionGoalEnv, DummyDiscrete, # noqa
DummyDiscreteActionGoalEnv, DummyDiscreteImg, DummyContinuousImg,
DummyFactoredContinuous, DummyMujocoEnv,
DummyTupleContinuous, DummyTupleDiscrete, DummyTupleMixed,
DummyTupleStateContinuous, DummyTupleStateDiscrete,
DummyTupleActionContinuous, DummyTupleActionDiscrete,
DummyHybridEnv)
DummyHybridEnv,
DummyGymnasiumAtariEnv, DummyGymnasiumMujocoEnv)

register(
id='FakeMujocoNNablaRL-v1',
Expand Down Expand Up @@ -87,3 +89,16 @@
entry_point='nnabla_rl.environments.dummy:DummyHybridEnv',
max_episode_steps=10
)


gymnasium_register(
id='FakeGymnasiumMujocoNNablaRL-v1',
entry_point='nnabla_rl.environments.dummy:DummyGymnasiumMujocoEnv',
max_episode_steps=10
)

gymnasium_register(
id='FakeGymnasiumAtariNNablaRLNoFrameskip-v1',
entry_point='nnabla_rl.environments.dummy:DummyGymnasiumAtariEnv',
max_episode_steps=10
)
89 changes: 88 additions & 1 deletion nnabla_rl/environments/dummy.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Copyright 2020,2021 Sony Corporation.
# Copyright 2021,2022,2023 Sony Group Corporation.
# Copyright 2021,2022,2023,2024 Sony Group Corporation.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -16,8 +16,10 @@
from typing import TYPE_CHECKING, cast

import gym
import gymnasium
import numpy as np
from gym.envs.registration import EnvSpec
from gymnasium.envs.registration import EnvSpec as GymnasiumEnvSpec

if TYPE_CHECKING:
from gym.utils.seeding import RandomNumberGenerator
Expand Down Expand Up @@ -309,3 +311,88 @@ def __init__(self, max_episode_steps=None):
super(DummyHybridEnv, self).__init__(max_episode_steps=max_episode_steps)
self.action_space = gym.spaces.Tuple((gym.spaces.Discrete(5), gym.spaces.Box(low=0.0, high=1.0, shape=(5, ))))
self.observation_space = gym.spaces.Box(low=0.0, high=1.0, shape=(5, ))


# =========== gymnasium ==========
class AbstractDummyGymnasiumEnv(gymnasium.Env):
def __init__(self, max_episode_steps):
self.spec = GymnasiumEnvSpec('dummy-v0', max_episode_steps=max_episode_steps)
self._episode_steps = 0

def reset(self):
self._episode_steps = 0
return self.observation_space.sample(), {}

def step(self, a):
next_state = self.observation_space.sample()
reward = np.random.randn()
terminated = False
if self.spec.max_episode_steps is None:
truncated = False
else:
truncated = bool(self._episode_steps < self.spec.max_episode_steps)
info = {'rnn_states': {'dummy_scope': {'dummy_state1': 1, 'dummy_state2': 2}}}
self._episode_steps += 1
return next_state, reward, terminated, truncated, info


class DummyGymnasiumAtariEnv(AbstractDummyGymnasiumEnv):
class DummyALE(object):
def __init__(self):
self._lives = 100

def lives(self):
self._lives -= 1
if self._lives < 0:
self._lives = 100
return self._lives

# seeding.np_random outputs np_random and seed
np_random = cast("RandomNumberGenerator", nnabla_rl.random.drng)

def __init__(self, done_at_random=True, max_episode_length=None):
super(DummyGymnasiumAtariEnv, self).__init__(
max_episode_steps=max_episode_length)
self.action_space = gymnasium.spaces.Discrete(4)
self.observation_space = gymnasium.spaces.Box(
low=0, high=255, shape=(84, 84, 3), dtype=np.uint8)
self.ale = DummyGymnasiumAtariEnv.DummyALE()
self._done_at_random = done_at_random
self._max_episode_length = max_episode_length
self._episode_length = None

def step(self, action):
assert self._episode_length is not None
observation = self.observation_space.sample()
self._episode_length += 1
if self._done_at_random:
done = bool(self.np_random.integers(10) == 0)
else:
done = False
if self._max_episode_length is not None:
done = (self._max_episode_length <= self._episode_length) or done
return observation, 1.0, done, {'needs_reset': False}

def reset(self):
self._episode_length = 0
return self.observation_space.sample()

def get_action_meanings(self):
return ['NOOP', 'FIRE', 'LEFT', 'RIGHT']


class DummyGymnasiumMujocoEnv(AbstractDummyGymnasiumEnv):
def __init__(self, max_episode_steps=None):
super(DummyGymnasiumMujocoEnv, self).__init__(max_episode_steps=max_episode_steps)
self.action_space = gymnasium.spaces.Box(low=0.0, high=1.0, shape=(5, ))
self.observation_space = gymnasium.spaces.Box(low=0.0, high=1.0, shape=(5, ))

def get_dataset(self):
dataset = {}
datasize = 2000
dataset['observations'] = np.stack([self.observation_space.sample() for _ in range(datasize)], axis=0)
dataset['actions'] = np.stack([self.action_space.sample() for _ in range(datasize)], axis=0)
dataset['rewards'] = np.random.randn(datasize, 1)
dataset['terminals'] = np.random.randint(2, size=(datasize, 1))
dataset['timeouts'] = np.zeros((datasize, 1))
return dataset
3 changes: 2 additions & 1 deletion nnabla_rl/environments/wrappers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Copyright 2020,2021 Sony Corporation.
# Copyright 2021,2022,2023 Sony Group Corporation.
# Copyright 2021,2022,2023,2024 Sony Group Corporation.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -20,3 +20,4 @@
from nnabla_rl.environments.wrappers.atari import make_atari, wrap_deepmind # noqa
from nnabla_rl.environments.wrappers.hybrid_env import (EmbedActionWrapper, FlattenActionWrapper, # noqa
RemoveStepWrapper, ScaleActionWrapper, ScaleStateWrapper)
from nnabla_rl.environments.wrappers.gymnasium import Gymnasium2GymWrapper # noqa
14 changes: 11 additions & 3 deletions nnabla_rl/environments/wrappers/atari.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Copyright 2020,2021 Sony Corporation.
# Copyright 2021,2022,2023 Sony Group Corporation.
# Copyright 2021,2022,2023,2024 Sony Group Corporation.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -16,10 +16,12 @@

import cv2
import gym
import gymnasium
import numpy as np
from gym import spaces

import nnabla_rl as rl
from nnabla_rl.environments.wrappers.gymnasium import Gymnasium2GymWrapper
from nnabla_rl.external.atari_wrappers import (ClipRewardEnv, EpisodicLifeEnv, FireResetEnv, MaxAndSkipEnv,
NoopResetEnv, ScaledFloatFrame)

Expand Down Expand Up @@ -97,8 +99,14 @@ def __array__(self, dtype=None):
return out


def make_atari(env_id, max_frames_per_episode=None):
env = gym.make(env_id)
def make_atari(env_id, max_frames_per_episode=None, use_gymnasium=False):
if use_gymnasium:
env = gymnasium.make(env_id)
env = Gymnasium2GymWrapper(env)
# gymnasium env is not wrapped TimeLimit wrapper
env = gym.wrappers.TimeLimit(env, max_episode_steps=env.spec.kwargs["max_num_frames_per_episode"])
else:
env = gym.make(env_id)
if max_frames_per_episode is not None:
env = env.unwrapped
env = gym.wrappers.TimeLimit(env, max_episode_steps=max_frames_per_episode)
Expand Down
88 changes: 88 additions & 0 deletions nnabla_rl/environments/wrappers/gymnasium.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
# Copyright 2024 Sony Group Corporation.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

import gym
from gym import spaces as gym_spaces
from gymnasium import spaces as gymnasium_spaces
from gymnasium.utils import seeding


class Gymnasium2GymWrapper(gym.Wrapper):
def __init__(self, env):
if isinstance(env, gym.Env) or isinstance(env, gym.Wrapper):
raise ValueError("'env' should not be an instance of 'gym.Env' and 'gym.Wrapper'")

super().__init__(env)

# observation space
if isinstance(env.observation_space, gymnasium_spaces.Tuple):
self.observation_space = gym_spaces.Tuple(
[self._translate_space(observation_space)
for observation_space in env.observation_space]
)
elif isinstance(env.observation_space, gymnasium_spaces.Dict):
self.observation_space = gym_spaces.Dict(
{key: self._translate_space(observation_space)
for key, observation_space in env.observation_space.items()}
)
else:
self.observation_space = self._translate_space(env.observation_space)

# action space
if isinstance(env.action_space, gymnasium_spaces.Tuple):
self.action_space = gym_spaces.Tuple(
[self._translate_space(action_space)
for action_space in env.action_space]
)
elif isinstance(env.action_space, gymnasium_spaces.Dict):
self.action_space = gym_spaces.Dict(
{key: self._translate_space(action_space)
for key, action_space in env.action_space.items()}
)
else:
self.action_space = self._translate_space(env.action_space)

def reset(self):
obs, _ = self.env.reset()
return obs

def step(self, action):
obs, reward, terminated, truncated, info = self.env.step(action)
done = (terminated or truncated)
info.update({"TimeLimit.truncated": truncated})
return obs, reward, done, info

def seed(self, seed: Optional[int] = None):
np_random, seed = seeding.np_random(seed)
self.env.np_random = np_random # type: ignore
return [seed]

@property
def unwrapped(self):
return self

def _translate_space(self, space):
if isinstance(space, gymnasium_spaces.Box):
return gym_spaces.Box(
low=space.low,
high=space.high,
shape=space.shape,
dtype=space.dtype
)
elif isinstance(space, gymnasium_spaces.Discrete):
return gym_spaces.Discrete(n=int(space.n))
else:
raise NotImplementedError
23 changes: 17 additions & 6 deletions nnabla_rl/utils/reproductions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Copyright 2020,2021 Sony Corporation.
# Copyright 2021,2022 Sony Group Corporation.
# Copyright 2021,2022,2023,2024 Sony Group Corporation.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -16,12 +16,14 @@
import random as py_random

import gym
import gymnasium
import numpy as np

import nnabla as nn
import nnabla_rl as rl
from nnabla_rl.environments.environment_info import EnvironmentInfo
from nnabla_rl.environments.wrappers import NumpyFloat32Env, ScreenRenderEnv, make_atari, wrap_deepmind
from nnabla_rl.environments.wrappers import (Gymnasium2GymWrapper, NumpyFloat32Env, ScreenRenderEnv, make_atari,
wrap_deepmind)
from nnabla_rl.logger import logger


Expand Down Expand Up @@ -53,11 +55,14 @@ def build_atari_env(id_or_env,
print_info=True,
max_frames_per_episode=None,
frame_stack=True,
flicker_probability=0.0):
flicker_probability=0.0,
use_gymnasium=False):
if isinstance(id_or_env, gym.Env):
env = id_or_env
elif isinstance(id_or_env, gymnasium.Env):
env = Gymnasium2GymWrapper(id_or_env)
else:
env = make_atari(id_or_env, max_frames_per_episode=max_frames_per_episode)
env = make_atari(id_or_env, max_frames_per_episode=max_frames_per_episode, use_gymnasium=use_gymnasium)
if print_info:
print_env_info(env)

Expand All @@ -75,7 +80,7 @@ def build_atari_env(id_or_env,
return env


def build_mujoco_env(id_or_env, test=False, seed=None, render=False, print_info=True):
def build_mujoco_env(id_or_env, test=False, seed=None, render=False, print_info=True, use_gymnasium=False):
try:
# Add pybullet env
import pybullet_envs # noqa
Expand All @@ -91,8 +96,14 @@ def build_mujoco_env(id_or_env, test=False, seed=None, render=False, print_info=

if isinstance(id_or_env, gym.Env):
env = id_or_env
elif isinstance(id_or_env, gymnasium.Env):
env = Gymnasium2GymWrapper(id_or_env)
else:
env = gym.make(id_or_env)
if use_gymnasium:
env = gymnasium.make(id_or_env)
env = Gymnasium2GymWrapper(env)
else:
env = gym.make(id_or_env)

if print_info:
print_env_info(env)
Expand Down
10 changes: 6 additions & 4 deletions reproductions/algorithms/atari/a2c/a2c_reproduction.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2021 Sony Group Corporation.
# Copyright 2021,2022,2023,2024 Sony Group Corporation.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -28,9 +28,9 @@

def run_training(args):
set_global_seed(args.seed)
train_env = build_atari_env(args.env, seed=args.seed)
train_env = build_atari_env(args.env, seed=args.seed, use_gymnasium=args.use_gymnasium)
eval_env = build_atari_env(
args.env, test=True, seed=args.seed + 100, render=args.render)
args.env, test=True, seed=args.seed + 100, render=args.render, use_gymnasium=args.use_gymnasium)

iteration_num_hook = H.IterationNumHook(timing=100)

Expand All @@ -56,7 +56,8 @@ def run_training(args):
def run_showcase(args):
if args.snapshot_dir is None:
raise ValueError('Please specify the snapshot dir for showcasing')
eval_env = build_atari_env(args.env, test=True, seed=args.seed + 200, render=False)
eval_env = build_atari_env(args.env, test=True, seed=args.seed + 200, render=False,
use_gymnasium=args.use_gymnasium)
config = A.A2CConfig(gpu_id=args.gpu)
a2c = serializers.load_snapshot(args.snapshot_dir, eval_env, algorithm_kwargs={"config": config})
if not isinstance(a2c, A.A2C):
Expand All @@ -83,6 +84,7 @@ def main():
parser.add_argument('--save_timing', type=int, default=250000)
parser.add_argument('--eval_timing', type=int, default=250000)
parser.add_argument('--showcase_runs', type=int, default=10)
parser.add_argument('--use-gymnasium', action='store_true')

args = parser.parse_args()

Expand Down
Loading

0 comments on commit b98c747

Please sign in to comment.