Add retro #73

Merged · 20 commits · May 18, 2023
84 changes: 84 additions & 0 deletions examples/common/custom_registration.py
@@ -0,0 +1,84 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2023 The OpenRL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

""""""
from typing import Optional

from gymnasium import Env

from examples.retro.retro_env import retro_all_envs
from openrl.envs.vec_env import (
    AsyncVectorEnv,
    RewardWrapper,
    SyncVectorEnv,
    VecMonitorWrapper,
)
from openrl.envs.vec_env.vec_info import VecInfoFactory
from openrl.rewards import RewardFactory


def make(
    id: str,
    cfg=None,
    env_num: int = 1,
    asynchronous: bool = False,
    add_monitor: bool = True,
    render_mode: Optional[str] = None,
    **kwargs,
) -> Env:
    if render_mode in [None, "human", "rgb_array"]:
        convert_render_mode = render_mode
    elif render_mode in ["group_human", "group_rgb_array"]:
        # will display all the envs (when render_mode == "group_human")
        # or return all the envs' images (when render_mode == "group_rgb_array")
        convert_render_mode = "rgb_array"
    elif render_mode == "single_human":
        # will only display the first env
        convert_render_mode = [None] * (env_num - 1)
        convert_render_mode = ["human"] + convert_render_mode
        render_mode = None
    elif render_mode == "single_rgb_array":
        # env.render() will only return the first env's image
        convert_render_mode = [None] * (env_num - 1)
        convert_render_mode = ["rgb_array"] + convert_render_mode
    else:
        raise NotImplementedError(f"render_mode {render_mode} is not supported.")

    if id in retro_all_envs:
        from examples.retro.retro_env import make_retro_envs

        env_fns = make_retro_envs(
            id=id, env_num=env_num, render_mode=convert_render_mode, **kwargs
        )
    else:
        raise NotImplementedError(f"env {id} is not supported.")

    if asynchronous:
        env = AsyncVectorEnv(env_fns, render_mode=render_mode)
    else:
        env = SyncVectorEnv(env_fns, render_mode=render_mode)

    reward_class = cfg.reward_class if cfg else None
    reward_class = RewardFactory.get_reward_class(reward_class, env)

    env = RewardWrapper(env, reward_class)

    if add_monitor:
        vec_info_class = cfg.vec_info_class if cfg else None
        vec_info_class = VecInfoFactory.get_vec_info_class(vec_info_class, env)
        env = VecMonitorWrapper(vec_info_class, env)

    return env
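
For orientation, a hypothetical usage sketch of the make() helper above (not part of the PR). It assumes a gym-retro/stable-retro ROM such as "Airstriker-Genesis" is installed; the game id and env_num are illustrative only.

# Illustrative usage of examples/common/custom_registration.make (assumed ROM).
from examples.common.custom_registration import make

# Build two synchronous vectorized copies of the retro game.
env = make("Airstriker-Genesis", env_num=2, asynchronous=False)
obs = env.reset()  # return format follows OpenRL's vectorized reset API
env.close()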
@@ -18,12 +18,13 @@

from typing import Callable, List, Optional, Union

import gymnasium as gym
import retro
from gymnasium import Env

from examples.retro.retro_env.retro_convert import RetroWrapper
from openrl.envs.common import build_envs
from openrl.envs.retro.retro_convert import RetroWrapper

retro_all_envs = retro.data.list_games()


def make_retro_envs(
@@ -18,10 +18,33 @@

from typing import Any, Dict, List, Optional, Union

import gymnasium as gym
import gymnasium
import numpy as np
import retro
from gymnasium import Wrapper
from retro import RetroEnv


class CustomRetroEnv(RetroEnv):
def __init__(self, game: str, **kwargs):
super(CustomRetroEnv, self).__init__(game, **kwargs)

def seed(self, seed: Optional[int] = None):
seed1 = np.random.seed(seed)

seed1 = np.random.randint(0, 2**31)
seed2 = np.random.randint(0, 2**31)

return [seed1, seed2]

def render(self, mode: Optional[str] = "human", close: Optional[bool] = False):
if close:
if self.viewer:
self.viewer.close()
return

img = self.get_screen() if self.img is None else self.img

return img


class RetroWrapper(Wrapper):
@@ -32,20 +55,20 @@ def __init__(
disable_env_checker: Optional[bool] = None,
**kwargs
):
self.env = retro.make(game=game, **kwargs)
self.env = CustomRetroEnv(game=game, **kwargs)

super().__init__(self.env)

shape = self.env.observation_space.shape
shape = (shape[2],) + shape[0:2]
self.observation_space = gym.spaces.Box(
self.observation_space = gymnasium.spaces.Box(
low=0,
high=255,
shape=shape,
dtype=self.env.observation_space.dtype,
)

self.action_space = gym.spaces.Discrete(self.env.action_space.n)
self.action_space = gymnasium.spaces.Discrete(self.env.action_space.n)

self.env_name = game

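
As a hedged aside, a minimal sketch of what the wrapper above exposes (not part of the PR); the game id is illustrative and assumes the corresponding ROM is installed.

# Illustrative only: inspect the spaces produced by RetroWrapper.
from examples.retro.retro_env.retro_convert import RetroWrapper

env = RetroWrapper(game="Airstriker-Genesis")
# The (H, W, C) screen reported by retro is exposed as a channel-first (C, H, W) Box.
print(env.observation_space.shape)
# The underlying action space is exposed as Discrete(n), built from its .n attribute.
print(env.action_space)
env.close()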
@@ -1,7 +1,7 @@
""""""
import numpy as np

from openrl.envs.common import make
from examples.common.custom_registration import make
from openrl.modules.common import PPONet as Net
from openrl.runners.common import PPOAgent as Agent

@@ -14,7 +14,7 @@ def train():
# initialize the trainer
agent = Agent(net)
# start training
agent.train(total_time_steps=20000)
agent.train(total_time_steps=2000)
# close the environment
env.close()
return agent
@@ -41,7 +41,7 @@ def game_test(agent):
print(f"{step}: reward:{np.mean(r)}")

if any(done):
env.reset()
break

env.close()

2 changes: 1 addition & 1 deletion openrl/buffers/normal_buffer.py
@@ -105,4 +105,4 @@ def naive_recurrent_generator(self, advantages, num_mini_batch):
def recurrent_generator(self, advantages, num_mini_batch, data_chunk_length):
return self.data.recurrent_generator(
advantages, num_mini_batch, data_chunk_length
)
)
2 changes: 1 addition & 1 deletion openrl/buffers/offpolicy_buffer.py
@@ -40,4 +40,4 @@ def get_buffer_size(self):
if self.data.first_insert_flag:
return self.data.step
else:
return self.buffer_size
return self.buffer_size
9 changes: 3 additions & 6 deletions openrl/buffers/offpolicy_replay_data.py
@@ -22,12 +22,9 @@
import torch
from torch.utils.data.sampler import BatchSampler, SubsetRandomSampler

from openrl.buffers.utils.obs_data import ObsData
from openrl.buffers.utils.util import (
get_critic_obs,
get_policy_obs,
)
from openrl.buffers.replay_data import ReplayData
from openrl.buffers.utils.obs_data import ObsData
from openrl.buffers.utils.util import get_critic_obs, get_policy_obs


class OffPolicyReplayData(ReplayData):
@@ -151,4 +148,4 @@ def after_update(self):
self.available_actions[0] = self.available_actions[-1].copy()

def compute_returns(self, next_value, value_normalizer=None):
pass
pass
41 changes: 25 additions & 16 deletions openrl/drivers/offpolicy_driver.py
@@ -15,9 +15,9 @@
# limitations under the License.

""""""
import random
from typing import Any, Dict, Optional

import random
import numpy as np
import torch
from torch.nn.parallel import DistributedDataParallel
@@ -38,31 +38,35 @@ def __init__(
client=None,
logger: Optional[Logger] = None,
) -> None:
super(OffPolicyDriver, self).__init__(config, trainer, buffer, rank, world_size, client, logger)
super(OffPolicyDriver, self).__init__(
config, trainer, buffer, rank, world_size, client, logger
)

self.buffer_minimal_size = int(config["cfg"].buffer_size * 0.2)
self.epsilon_start = config.epsilon_start
self.epsilon_finish = config.epsilon_finish
self.epsilon_anneal_time = config.epsilon_anneal_time

def _inner_loop(
self,
self,
) -> None:
rollout_infos = self.actor_rollout()

if self.buffer.get_buffer_size() > self.buffer_minimal_size:
train_infos = self.learner_update()
self.buffer.after_update()
else:
train_infos = {'value_loss': 0,
'policy_loss': 0,
'dist_entropy': 0,
'actor_grad_norm': 0,
'critic_grad_norm': 0,
'ratio': 0}
train_infos = {
"value_loss": 0,
"policy_loss": 0,
"dist_entropy": 0,
"actor_grad_norm": 0,
"critic_grad_norm": 0,
"ratio": 0,
}

self.total_num_steps = (
(self.episode + 1) * self.episode_length * self.n_rollout_threads
(self.episode + 1) * self.episode_length * self.n_rollout_threads
)

if self.episode % self.log_interval == 0:
@@ -161,13 +165,13 @@ def compute_returns(self):
np.split(_t2n(next_values), self.learner_n_rollout_threads)
)
if "critic" in self.trainer.algo_module.models and isinstance(
self.trainer.algo_module.models["critic"], DistributedDataParallel
self.trainer.algo_module.models["critic"], DistributedDataParallel
):
value_normalizer = self.trainer.algo_module.models[
"critic"
].module.value_normalizer
elif "model" in self.trainer.algo_module.models and isinstance(
self.trainer.algo_module.models["model"], DistributedDataParallel
self.trainer.algo_module.models["model"], DistributedDataParallel
):
value_normalizer = self.trainer.algo_module.models["model"].value_normalizer
else:
@@ -176,8 +180,8 @@

@torch.no_grad()
def act(
self,
step: int,
self,
step: int,
):
self.trainer.prep_rollout()

@@ -194,7 +198,12 @@
rnn_states = np.array(np.split(_t2n(rnn_states), self.n_rollout_threads))

# todo add epsilon greedy
epsilon = self.epsilon_finish + (self.epsilon_start - self.epsilon_finish) / self.epsilon_anneal_time * step
epsilon = (
self.epsilon_finish
+ (self.epsilon_start - self.epsilon_finish)
/ self.epsilon_anneal_time
* step
)
if random.random() > epsilon:
action = q_values.argmax().item()
else:
Expand All @@ -204,4 +213,4 @@ def act(
q_values,
action,
rnn_states,
)
)
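
For comparison, a self-contained sketch of a conventional linearly annealed epsilon-greedy selector; the parameter names mirror the config fields used above, and the decay direction (from epsilon_start down to epsilon_finish) is an assumption rather than something the diff pins down.

import random

import numpy as np


def linear_epsilon(step, epsilon_start=1.0, epsilon_finish=0.05, epsilon_anneal_time=50000):
    # Interpolate from epsilon_start at step 0 down to epsilon_finish
    # once step reaches epsilon_anneal_time (assumed decay direction).
    frac = min(step / epsilon_anneal_time, 1.0)
    return epsilon_start + frac * (epsilon_finish - epsilon_start)


def epsilon_greedy(q_values, step):
    # Explore with probability epsilon, otherwise act greedily on the Q-values.
    if random.random() < linear_epsilon(step):
        return random.randrange(len(q_values))
    return int(np.argmax(q_values))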