-
Notifications
You must be signed in to change notification settings - Fork 387
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feature(lxy): modify ppof rewardclip and add atari config #589
Changes from 3 commits
22ed628
b7de3f5
2ca2234
0e50b47
cc220c8
47bb5a4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -77,6 +77,39 @@ def get_instance_config(env: str) -> EasyDict: | |
critic_head_hidden_size=256, | ||
actor_head_hidden_size=256, | ||
) | ||
elif env == 'qbert': | ||
cfg.n_sample = 1024 | ||
cfg.batch_size = 128 | ||
cfg.epoch_per_collect = 10 | ||
cfg.learning_rate = 0.0001 | ||
cfg.model = dict( | ||
encoder_hidden_size_list=[32, 64, 64, 128], | ||
actor_head_hidden_size=128, | ||
critic_head_hidden_size=128, | ||
critic_head_layer_num=2, | ||
) | ||
elif env == 'kangaroo': | ||
cfg.n_sample = 1024 | ||
cfg.batch_size = 128 | ||
cfg.epoch_per_collect = 10 | ||
cfg.learning_rate = 0.0001 | ||
cfg.model = dict( | ||
encoder_hidden_size_list=[32, 64, 64, 128], | ||
actor_head_hidden_size=128, | ||
critic_head_hidden_size=128, | ||
critic_head_layer_num=2, | ||
) | ||
elif env == 'bowling': | ||
cfg.n_sample = 1024 | ||
cfg.batch_size = 128 | ||
cfg.epoch_per_collect = 10 | ||
cfg.learning_rate = 0.0001 | ||
cfg.model = dict( | ||
encoder_hidden_size_list=[32, 64, 64, 128], | ||
actor_head_hidden_size=128, | ||
critic_head_hidden_size=128, | ||
critic_head_layer_num=2, | ||
) | ||
else: | ||
raise KeyError("not supported env type: {}".format(env)) | ||
return cfg | ||
|
@@ -152,6 +185,36 @@ def get_instance_env(env: str) -> BaseEnv: | |
}, | ||
seed_api=False, | ||
) | ||
elif env == 'qbert': | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. unify these code and remove unused code There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. rename to |
||
from dizoo.atari.envs.atari_env import AtariEnv | ||
cfg = EasyDict({ | ||
'env_id': 'QbertNoFrameskip-v4', | ||
'env_wrapper': 'atari_default', | ||
}) | ||
ding_env_atari = DingEnvWrapper(gym.make('QbertNoFrameskip-v4'), cfg=cfg) | ||
#ding_env_atari.enable_save_replay('atari_log/') | ||
obs = ding_env_atari.reset() | ||
return ding_env_atari | ||
elif env == 'kangaroo': | ||
from dizoo.atari.envs.atari_env import AtariEnv | ||
cfg = EasyDict({ | ||
'env_id': 'KangarooNoFrameskip-v4', | ||
'env_wrapper': 'atari_default', | ||
}) | ||
ding_env_atari = DingEnvWrapper(gym.make('KangarooNoFrameskip-v4'), cfg=cfg) | ||
#ding_env_atari.enable_save_replay('atari_log/') | ||
obs = ding_env_atari.reset() | ||
return ding_env_atari | ||
elif env == 'bowling': | ||
from dizoo.atari.envs.atari_env import AtariEnv | ||
cfg = EasyDict({ | ||
'env_id': 'BowlingNoFrameskip-v4', | ||
'env_wrapper': 'atari_default', | ||
}) | ||
ding_env_atari = DingEnvWrapper(gym.make('BowlingNoFrameskip-v4'), cfg=cfg) | ||
#ding_env_atari.enable_save_replay('atari_log/') | ||
obs = ding_env_atari.reset() | ||
return ding_env_atari | ||
else: | ||
raise KeyError("not supported env type: {}".format(env)) | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,7 @@ | ||
from typing import Optional, Union | ||
from ditk import logging | ||
from easydict import EasyDict | ||
from functools import partial | ||
import os | ||
import gym | ||
import torch | ||
|
@@ -30,6 +31,10 @@ class PPOF: | |
'mario', | ||
'di_sheep', | ||
'procgen_bigfish', | ||
# atari | ||
'qbert', | ||
'kangaroo', | ||
'bowling', | ||
] | ||
|
||
def __init__( | ||
|
@@ -86,7 +91,7 @@ def train( | |
logging.debug(self.policy._model) | ||
# define env and policy | ||
collector_env = self._setup_env_manager(collector_env_num, context, debug) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. add |
||
evaluator_env = self._setup_env_manager(evaluator_env_num, context, debug) | ||
evaluator_env = self._setup_env_manager(evaluator_env_num, context, debug, 'evaluator') | ||
|
||
with task.start(ctx=OnlineRLContext()): | ||
task.use(interaction_evaluator_ttorch(self.seed, self.policy, evaluator_env)) | ||
|
@@ -168,7 +173,7 @@ def batch_evaluate( | |
if debug: | ||
logging.getLogger().setLevel(logging.DEBUG) | ||
# define env and policy | ||
env = self._setup_env_manager(env_num, context, debug) | ||
env = self._setup_env_manager(env_num, context, debug, 'evaluator') | ||
if ckpt_path is None: | ||
ckpt_path = os.path.join(self.exp_name, 'ckpt/eval.pth.tar') | ||
state_dict = torch.load(ckpt_path, map_location='cpu') | ||
|
@@ -179,7 +184,13 @@ def batch_evaluate( | |
task.use(interaction_evaluator_ttorch(self.seed, self.policy, env, n_evaluator_episode)) | ||
task.run(max_step=1) | ||
|
||
def _setup_env_manager(self, env_num: int, context: Optional[str] = None, debug: bool = False) -> BaseEnvManagerV2: | ||
def _setup_env_manager( | ||
self, | ||
env_num: int, | ||
context: Optional[str] = None, | ||
debug: bool = False, | ||
caller: str = 'collector' | ||
) -> BaseEnvManagerV2: | ||
PaParaZz1 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
if debug: | ||
env_cls = BaseEnvManagerV2 | ||
manager_cfg = env_cls.default_config() | ||
|
@@ -188,4 +199,4 @@ def _setup_env_manager(self, env_num: int, context: Optional[str] = None, debug: | |
manager_cfg = env_cls.default_config() | ||
if context is not None: | ||
manager_cfg.context = context | ||
return env_cls([self.env.clone for _ in range(env_num)], manager_cfg) | ||
return env_cls([partial(self.env.clone, caller) for _ in range(env_num)], manager_cfg) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
qbert can use the same config as kangaroo and bowling