
feature(lxy): modify ppof rewardclip and add atari config #589

Merged (6 commits, Mar 6, 2023)
Changes from 3 commits
63 changes: 63 additions & 0 deletions ding/bonus/config.py
@@ -77,6 +77,39 @@ def get_instance_config(env: str) -> EasyDict:
critic_head_hidden_size=256,
actor_head_hidden_size=256,
)
elif env == 'qbert':
Review comment (Member): qbert can use the same config as kangaroo and bowling.

cfg.n_sample = 1024
cfg.batch_size = 128
cfg.epoch_per_collect = 10
cfg.learning_rate = 0.0001
cfg.model = dict(
encoder_hidden_size_list=[32, 64, 64, 128],
actor_head_hidden_size=128,
critic_head_hidden_size=128,
critic_head_layer_num=2,
)
elif env == 'kangaroo':
cfg.n_sample = 1024
cfg.batch_size = 128
cfg.epoch_per_collect = 10
cfg.learning_rate = 0.0001
cfg.model = dict(
encoder_hidden_size_list=[32, 64, 64, 128],
actor_head_hidden_size=128,
critic_head_hidden_size=128,
critic_head_layer_num=2,
)
elif env == 'bowling':
cfg.n_sample = 1024
cfg.batch_size = 128
cfg.epoch_per_collect = 10
cfg.learning_rate = 0.0001
cfg.model = dict(
encoder_hidden_size_list=[32, 64, 64, 128],
actor_head_hidden_size=128,
critic_head_hidden_size=128,
critic_head_layer_num=2,
)
else:
raise KeyError("not supported env type: {}".format(env))
return cfg
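
The three Atari branches above are byte-for-byte identical, which is what the review comment is pointing at. A minimal sketch of the suggested unification (not part of this PR's diff):

    elif env in ('qbert', 'kangaroo', 'bowling'):
        # One shared PPO config covers all three Atari games.
        cfg.n_sample = 1024
        cfg.batch_size = 128
        cfg.epoch_per_collect = 10
        cfg.learning_rate = 0.0001
        cfg.model = dict(
            encoder_hidden_size_list=[32, 64, 64, 128],
            actor_head_hidden_size=128,
            critic_head_hidden_size=128,
            critic_head_layer_num=2,
        )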
@@ -152,6 +185,36 @@ def get_instance_env(env: str) -> BaseEnv:
},
seed_api=False,
)
elif env == 'qbert':
Review comment (Member): unify this code and remove the unused code.

Review comment (Member): rename to atari_qbert.
from dizoo.atari.envs.atari_env import AtariEnv
cfg = EasyDict({
'env_id': 'QbertNoFrameskip-v4',
'env_wrapper': 'atari_default',
})
ding_env_atari = DingEnvWrapper(gym.make('QbertNoFrameskip-v4'), cfg=cfg)
#ding_env_atari.enable_save_replay('atari_log/')
obs = ding_env_atari.reset()
return ding_env_atari
elif env == 'kangaroo':
from dizoo.atari.envs.atari_env import AtariEnv
cfg = EasyDict({
'env_id': 'KangarooNoFrameskip-v4',
'env_wrapper': 'atari_default',
})
ding_env_atari = DingEnvWrapper(gym.make('KangarooNoFrameskip-v4'), cfg=cfg)
#ding_env_atari.enable_save_replay('atari_log/')
obs = ding_env_atari.reset()
return ding_env_atari
elif env == 'bowling':
from dizoo.atari.envs.atari_env import AtariEnv
cfg = EasyDict({
'env_id': 'BowlingNoFrameskip-v4',
'env_wrapper': 'atari_default',
})
ding_env_atari = DingEnvWrapper(gym.make('BowlingNoFrameskip-v4'), cfg=cfg)
#ding_env_atari.enable_save_replay('atari_log/')
obs = ding_env_atari.reset()
return ding_env_atari
else:
raise KeyError("not supported env type: {}".format(env))
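
The three constructors differ only in the Gym id, so the "unify" review comment could be addressed with a single branch that also drops the unused AtariEnv import and the unused obs variable. A sketch, assuming the short-name-to-id mapping below:

    elif env in ('qbert', 'kangaroo', 'bowling'):
        env_id = {
            'qbert': 'QbertNoFrameskip-v4',
            'kangaroo': 'KangarooNoFrameskip-v4',
            'bowling': 'BowlingNoFrameskip-v4',
        }[env]
        cfg = EasyDict({'env_id': env_id, 'env_wrapper': 'atari_default'})
        return DingEnvWrapper(gym.make(env_id), cfg=cfg)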

19 changes: 15 additions & 4 deletions ding/bonus/ppof.py
@@ -1,6 +1,7 @@
from typing import Optional, Union
from ditk import logging
from easydict import EasyDict
from functools import partial
import os
import gym
import torch
@@ -30,6 +31,10 @@ class PPOF:
'mario',
'di_sheep',
'procgen_bigfish',
# atari
'qbert',
'kangaroo',
'bowling',
]

def __init__(
@@ -86,7 +91,7 @@ def train(
logging.debug(self.policy._model)
# define env and policy
collector_env = self._setup_env_manager(collector_env_num, context, debug)
Review comment (Member): add collector here.
-        evaluator_env = self._setup_env_manager(evaluator_env_num, context, debug)
+        evaluator_env = self._setup_env_manager(evaluator_env_num, context, debug, 'evaluator')

with task.start(ctx=OnlineRLContext()):
task.use(interaction_evaluator_ttorch(self.seed, self.policy, evaluator_env))
Expand Down Expand Up @@ -168,7 +173,7 @@ def batch_evaluate(
if debug:
logging.getLogger().setLevel(logging.DEBUG)
# define env and policy
-        env = self._setup_env_manager(env_num, context, debug)
+        env = self._setup_env_manager(env_num, context, debug, 'evaluator')
if ckpt_path is None:
ckpt_path = os.path.join(self.exp_name, 'ckpt/eval.pth.tar')
state_dict = torch.load(ckpt_path, map_location='cpu')
Expand All @@ -179,7 +184,13 @@ def batch_evaluate(
task.use(interaction_evaluator_ttorch(self.seed, self.policy, env, n_evaluator_episode))
task.run(max_step=1)

-    def _setup_env_manager(self, env_num: int, context: Optional[str] = None, debug: bool = False) -> BaseEnvManagerV2:
+    def _setup_env_manager(
+            self,
+            env_num: int,
+            context: Optional[str] = None,
+            debug: bool = False,
+            caller: str = 'collector'
+    ) -> BaseEnvManagerV2:
PaParaZz1 marked this conversation as resolved.
if debug:
env_cls = BaseEnvManagerV2
manager_cfg = env_cls.default_config()
@@ -188,4 +199,4 @@ def _setup_env_manager(self, env_num: int, context: Optional[str] = None, debug:
manager_cfg = env_cls.default_config()
if context is not None:
manager_cfg.context = context
-        return env_cls([self.env.clone for _ in range(env_num)], manager_cfg)
+        return env_cls([partial(self.env.clone, caller) for _ in range(env_num)], manager_cfg)
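
BaseEnvManagerV2 takes a list of zero-argument env factories, so partial binds the caller string to each clone call before the manager invokes it. A standalone illustration of the pattern (the Env class and names below are illustrative, not DI-engine API):

    from functools import partial

    class Env:

        def clone(self, caller: str = 'collector') -> 'Env':
            print('cloned for', caller)
            return Env()

    env = Env()
    factories = [partial(env.clone, 'evaluator') for _ in range(4)]
    envs = [factory() for factory in factories]  # prints 'cloned for evaluator' four times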
6 changes: 4 additions & 2 deletions ding/envs/env/default_wrapper.py
@@ -5,7 +5,8 @@
eval_episode_return_wrapper = EasyDict(type='eval_episode_return')


-def get_default_wrappers(env_wrapper_name: str, env_id: Optional[str] = None) -> List[dict]:
+def get_default_wrappers(env_wrapper_name: str, env_id: Optional[str] = None, caller: str = 'collector') -> List[dict]:
+    assert caller in ('collector', 'evaluator')
if env_wrapper_name == 'mujoco_default':
return [
EasyDict(type='delay_reward', kwargs=dict(delay_reward_step=3)),
@@ -21,7 +22,8 @@ def get_default_wrappers(env_wrapper_name: str, env_id: Optional[str] = None) ->
wrapper_list.append(EasyDict(type='fire_reset'))
wrapper_list.append(EasyDict(type='warp_frame'))
wrapper_list.append(EasyDict(type='scaled_float_frame'))
-        wrapper_list.append(EasyDict(type='clip_reward'))
+        if caller == 'collector':
+            wrapper_list.append(EasyDict(type='clip_reward'))
wrapper_list.append(EasyDict(type='frame_stack', kwargs=dict(n_frames=4)))
wrapper_list.append(copy.deepcopy(eval_episode_return_wrapper))
return wrapper_list
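
This is the reward-clip change from the PR title: clipping now applies only to collector envs, so training still sees the stabilized {-1, 0, +1} signal while evaluator envs report unclipped episode returns, i.e. the true game score. For reference, a sign-clipping wrapper in the style of standard Atari preprocessing; a sketch for illustration, not DI-engine's registered 'clip_reward' wrapper:

    import gym
    import numpy as np

    class SignClipReward(gym.RewardWrapper):
        # Clip each reward to its sign, as in common Atari preprocessing.

        def reward(self, reward: float) -> float:
            return float(np.sign(reward))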
15 changes: 8 additions & 7 deletions ding/envs/env/ding_env_wrapper.py
@@ -15,7 +15,7 @@

class DingEnvWrapper(BaseEnv):

-    def __init__(self, env: gym.Env = None, cfg: dict = None, seed_api: bool = True) -> None:
+    def __init__(self, env: gym.Env = None, cfg: dict = None, seed_api: bool = True, caller: str = 'collector') -> None:
"""
You can pass in either an env instance, or a config to create an env instance:
- An env instance: Parameter `env` must not be `None`, but should be the instance.
@@ -25,6 +25,7 @@ def __init__(self, env: gym.Env = None, cfg: dict = None, seed_api: bool = True)
self._raw_env = env
self._cfg = cfg
self._seed_api = seed_api # some env may disable `env.seed` api
self._caller = caller
if self._cfg is None:
self._cfg = dict()
self._cfg = EasyDict(self._cfg)
@@ -37,7 +38,7 @@ def __init__(self, env: gym.Env = None, cfg: dict = None, seed_api: bool = True)
if env is not None:
self._init_flag = True
self._env = env
-            self._wrap_env()
+            self._wrap_env(caller)
self._observation_space = self._env.observation_space
self._action_space = self._env.action_space
self._action_space.seed(0) # default seed
@@ -57,7 +58,7 @@ def __init__(self, env: gym.Env = None, cfg: dict = None, seed_api: bool = True)
def reset(self) -> None:
if not self._init_flag:
self._env = gym.make(self._cfg.env_id)
-            self._wrap_env()
+            self._wrap_env(self._caller)
self._observation_space = self._env.observation_space
self._action_space = self._env.action_space
self._reward_space = gym.spaces.Box(
@@ -149,11 +150,11 @@ def random_action(self) -> np.ndarray:
)
return random_action

-    def _wrap_env(self) -> None:
+    def _wrap_env(self, caller: str = 'collector') -> None:
# wrapper_cfgs: Union[str, List]
wrapper_cfgs = self._cfg.env_wrapper
if isinstance(wrapper_cfgs, str):
-            wrapper_cfgs = get_default_wrappers(wrapper_cfgs, self._cfg.env_id)
+            wrapper_cfgs = get_default_wrappers(wrapper_cfgs, self._cfg.env_id, caller)
# self._wrapper_cfgs: List[Union[Callable, Dict]]
self._wrapper_cfgs = wrapper_cfgs
for wrapper in self._wrapper_cfgs:
@@ -197,12 +198,12 @@ def action_space(self) -> gym.spaces.Space:
def reward_space(self) -> gym.spaces.Space:
return self._reward_space

-    def clone(self) -> BaseEnv:
+    def clone(self, caller: str = 'collector') -> BaseEnv:
try:
spec = copy.deepcopy(self._raw_env.spec)
raw_env = CloudPickleWrapper(self._raw_env)
raw_env = copy.deepcopy(raw_env).data
raw_env.__setattr__('spec', spec)
except Exception:
raw_env = self._raw_env
-        return DingEnvWrapper(raw_env, self._cfg, self._seed_api)
+        return DingEnvWrapper(raw_env, self._cfg, self._seed_api, caller)
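
With caller threaded through clone and into get_default_wrappers, collector and evaluator clones of the same env get different wrapper stacks. A usage sketch, assuming the Atari env built in ding/bonus/config.py above:

    collector_env = ding_env_atari.clone('collector')  # wrapper stack includes clip_reward
    evaluator_env = ding_env_atari.clone('evaluator')  # clip_reward omitted; raw game scores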