
env(zjow): add env gym_pybullet_drones #526

Merged · 16 commits · Nov 4, 2022
1 change: 1 addition & 0 deletions README.md
@@ -272,6 +272,7 @@ P.S: The `.py` file in `Runnable Demo` can be found in `dizoo`
| 28 |[mario](https://github.com/Kautenja/gym-super-mario-bros) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![original](./dizoo/mario/mario.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/mario) <br>[环境指南](https://di-engine-docs.readthedocs.io/zh_CN/latest/13_envs/gym_super_mario_bros_zh.html) |
| 29 |[dmc2gym](https://github.com/denisyarats/dmc2gym) | ![continuous](https://img.shields.io/badge/-continous-green) | ![original](./dizoo/dmc2gym/dmc2gym_cheetah.png) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/dmc2gym)<br>[env tutorial](https://di-engine-docs.readthedocs.io/en/latest/13_envs/dmc2gym.html)<br>[环境指南](https://di-engine-docs.readthedocs.io/zh_CN/latest/13_envs/dmc2gym_zh.html) |
| 30 |[evogym](https://github.com/EvolutionGym/evogym) | ![continuous](https://img.shields.io/badge/-continous-green) | ![original](./dizoo/evogym/evogym.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/evogym/envs)<br>环境指南 |
| 31 |[gym-pybullet-drones](https://github.com/utiasDSL/gym-pybullet-drones) | ![continuous](https://img.shields.io/badge/-continous-green) | ![original](./dizoo/gym-pybullet-drones/gym-pybullet-drones.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/gym_pybullet_drones/envs)<br>环境指南 |

![discrete](https://img.shields.io/badge/-discrete-brightgreen) means discrete action space

16 changes: 15 additions & 1 deletion ding/entry/serial_entry_onpolicy.py
@@ -85,7 +85,7 @@ def serial_pipeline_onpolicy(
        collect_kwargs = commander.step()
        # Evaluate policy performance
        if evaluator.should_eval(learner.train_iter):
-            stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
+            stop, eval_info = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
            if stop:
                break
        # Collect data by default config n_sample/n_episode
@@ -98,4 +98,18 @@

    # Learner's after_run hook.
    learner.call_hook('after_run')
+    import time
+    import pickle
+    import numpy as np
+    with open(os.path.join(cfg.exp_name, 'result.pkl'), 'wb') as f:
+        eval_value_raw = [d['final_eval_reward'] for d in eval_info]
+        final_data = {
+            'stop': stop,
+            'env_step': collector.envstep,
+            'train_iter': learner.train_iter,
+            'eval_value': np.mean(eval_value_raw),
+            'eval_value_raw': eval_value_raw,
+            'finish_time': time.ctime(),
+        }
+        pickle.dump(final_data, f)
    return policy
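The `result.pkl` summary written above can be inspected offline after a run finishes. A minimal sketch, assuming a completed experiment whose `exp_name` was `flythrugate_ppo_seed0` (substitute your own run directory):

```python
# Load the final-result summary saved by serial_pipeline_onpolicy.
import pickle

with open('flythrugate_ppo_seed0/result.pkl', 'rb') as f:  # hypothetical exp_name
    final_data = pickle.load(f)

print(final_data['eval_value'])      # mean of the last evaluation's episode rewards
print(final_data['eval_value_raw'])  # per-episode 'final_eval_reward' list
print(final_data['env_step'], final_data['train_iter'], final_data['finish_time'])
```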
Empty file.
61 changes: 61 additions & 0 deletions dizoo/gym_pybullet_drones/config/flythrugate_onppo_config.py
@@ -0,0 +1,61 @@
from easydict import EasyDict

flythrugate_ppo_config = dict(
    exp_name='flythrugate_ppo_seed0',
    env=dict(
        manager=dict(shared_memory=False, reset_inplace=True),
        env_id='flythrugate-aviary-v0',
        norm_obs=dict(use_norm=False, ),
        norm_reward=dict(use_norm=False, ),
        collector_env_num=8,
        evaluator_env_num=8,
        use_act_scale=True,
        n_evaluator_episode=8,
        stop_value=0,
        action_type="VEL",
    ),
    policy=dict(
        cuda=True,
        recompute_adv=True,
        # load_path="./flythrugate_ppo_seed0/ckpt/ckpt_best.pth.tar",
        model=dict(
            obs_shape=12,
            action_shape=4,
            action_space='continuous',
        ),
        action_space='continuous',
        learn=dict(
            epoch_per_collect=10,
            batch_size=64,
            learning_rate=3e-4,
            value_weight=0.5,
            entropy_weight=0.0,
            clip_ratio=0.2,
            adv_norm=True,
            value_norm=True,
        ),
        collect=dict(
            n_sample=2048,
            gae_lambda=0.97,
        ),
        eval=dict(evaluator=dict(eval_freq=5000, )),
    ),
)
flythrugate_ppo_config = EasyDict(flythrugate_ppo_config)
main_config = flythrugate_ppo_config

flythrugate_ppo_create_config = dict(
    env=dict(
        type='gym_pybullet_drones',
        import_names=['dizoo.gym_pybullet_drones.envs.gym_pybullet_drones_env'],
    ),
    env_manager=dict(type='subprocess'),
    policy=dict(type='ppo', ),
)
flythrugate_ppo_create_config = EasyDict(flythrugate_ppo_create_config)
create_config = flythrugate_ppo_create_config

if __name__ == "__main__":
    # Alternatively, run `ding -m serial_onpolicy -c flythrugate_onppo_config.py -s 0 --env-step 1e7`
    from ding.entry import serial_pipeline_onpolicy
    serial_pipeline_onpolicy((main_config, create_config), seed=0)
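Before training, it is worth confirming that `gym-pybullet-drones` is installed and that the registered env matches the shapes assumed above (`obs_shape=12`, `action_shape=4`). A minimal sketch, assuming a 2022-era `gym_pybullet_drones` release that registers the `*-aviary-v0` ids on import:

```python
# Sanity check: the env should expose the spaces the config assumes.
import gym
import gym_pybullet_drones  # noqa: F401 -- importing registers the aviary envs

env = gym.make('flythrugate-aviary-v0')
print(env.observation_space)  # expected: a 12-dim Box with KIN observations
print(env.action_space)       # expected: a 4-dim Box
env.close()
```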
61 changes: 61 additions & 0 deletions dizoo/gym_pybullet_drones/config/takeoffaviary_onppo_config.py
@@ -0,0 +1,61 @@
from easydict import EasyDict

takeoffaviary_ppo_config = dict(
    exp_name='takeoffaviary_ppo_seed0',
    env=dict(
        manager=dict(shared_memory=False, reset_inplace=True),
        env_id='takeoff-aviary-v0',
        norm_obs=dict(use_norm=False, ),
        norm_reward=dict(use_norm=False, ),
        collector_env_num=8,
        evaluator_env_num=8,
        use_act_scale=True,
        n_evaluator_episode=8,
        stop_value=0,
        action_type="VEL",
    ),
    policy=dict(
        cuda=True,
        recompute_adv=True,
        # load_path="./takeoffaviary_ppo_seed0/ckpt/ckpt_best.pth.tar",
        model=dict(
            obs_shape=12,
            action_shape=4,
            action_space='continuous',
        ),
        action_space='continuous',
        learn=dict(
            epoch_per_collect=10,  # reduce
            batch_size=64,
            learning_rate=3e-4,  # tune; pytorch lr scheduler
            value_weight=0.5,
            entropy_weight=0.0,  # 0.001
            clip_ratio=0.2,  # 0.1
            adv_norm=True,
            value_norm=True,
        ),
        collect=dict(
            n_sample=2048,
            gae_lambda=0.97,
        ),
        eval=dict(evaluator=dict(eval_freq=5000, )),
    ),
)
takeoffaviary_ppo_config = EasyDict(takeoffaviary_ppo_config)
main_config = takeoffaviary_ppo_config

takeoffaviary_ppo_create_config = dict(
    env=dict(
        type='gym_pybullet_drones',
        import_names=['dizoo.gym_pybullet_drones.envs.gym_pybullet_drones_env'],
    ),
    env_manager=dict(type='subprocess'),
    policy=dict(type='ppo', ),
)
takeoffaviary_ppo_create_config = EasyDict(takeoffaviary_ppo_create_config)
create_config = takeoffaviary_ppo_create_config

if __name__ == "__main__":
    # Alternatively, run `ding -m serial_onpolicy -c takeoffaviary_onppo_config.py -s 0 --env-step 1e7`
    from ding.entry import serial_pipeline_onpolicy
    serial_pipeline_onpolicy((main_config, create_config), seed=0)
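The `# tune; pytorch lr scheduler` note above suggests annealing the learning rate while tuning. A standalone sketch of one such schedule in plain PyTorch; it is not wired into DI-engine's PPO optimizer here, and the parameter is a stand-in:

```python
# Illustrative only: linearly anneal the LR from 3e-4 to 3e-5 over 1000 steps.
import torch

params = [torch.nn.Parameter(torch.zeros(1))]  # stand-in for policy parameters
optimizer = torch.optim.Adam(params, lr=3e-4)
scheduler = torch.optim.lr_scheduler.LinearLR(
    optimizer, start_factor=1.0, end_factor=0.1, total_iters=1000
)

for _ in range(1000):  # one scheduler step per training iteration
    optimizer.step()
    scheduler.step()

print(optimizer.param_groups[0]['lr'])  # ~3e-5 after the final step
```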
55 changes: 55 additions & 0 deletions dizoo/gym_pybullet_drones/entry/flythrugate_onppo_eval.py
@@ -0,0 +1,55 @@
import os
import gym
import torch
from tensorboardX import SummaryWriter
from easydict import EasyDict

from ding.config import compile_config
from ding.worker import BaseLearner, SampleSerialCollector, InteractionSerialEvaluator, NaiveReplayBuffer
from ding.envs import BaseEnvManager, DingEnvWrapper
from ding.policy import PPOPolicy
from ding.model import VAC
from ding.utils import set_pkg_seed

from dizoo.gym_pybullet_drones.envs.gym_pybullet_drones_env import GymPybulletDronesEnv
from dizoo.gym_pybullet_drones.config.flythrugate_onppo_config import flythrugate_ppo_config


def main(cfg, seed=0, max_iterations=int(1e10)):
    cfg = compile_config(
        cfg,
        BaseEnvManager,
        PPOPolicy,
        BaseLearner,
        SampleSerialCollector,
        InteractionSerialEvaluator,
        NaiveReplayBuffer,
        save_cfg=True
    )
    collector_env_num, evaluator_env_num = cfg.env.collector_env_num, cfg.env.evaluator_env_num

    # Enable rendering and debug output for evaluation runs.
    cfg.env['record'] = True
    cfg.env['gui'] = True
    cfg.env['print_debug_info'] = True
    cfg.env['plot_observation'] = True

    evaluator_env = BaseEnvManager(
        env_fn=[lambda: GymPybulletDronesEnv(cfg.env) for _ in range(evaluator_env_num)], cfg=cfg.env.manager
    )

    evaluator_env.seed(seed, dynamic_seed=False)
    set_pkg_seed(seed, use_cuda=cfg.policy.cuda)

    model = VAC(**cfg.policy.model)
    policy = PPOPolicy(cfg.policy, model=model)
    policy.eval_mode.load_state_dict(torch.load(cfg.policy.load_path, map_location='cpu'))

    tb_logger = SummaryWriter(os.path.join('./log/', 'serial'))
    evaluator = InteractionSerialEvaluator(cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger)
    evaluator.eval()


if __name__ == "__main__":
    main(flythrugate_ppo_config)
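Note that `load_path` is commented out in `flythrugate_onppo_config.py`, so this script will fail at `torch.load(cfg.policy.load_path, ...)` unless a checkpoint path is set first. A minimal sketch; the checkpoint path below is hypothetical:

```python
# Set a trained checkpoint before invoking the eval entry.
from dizoo.gym_pybullet_drones.config.flythrugate_onppo_config import flythrugate_ppo_config
from dizoo.gym_pybullet_drones.entry.flythrugate_onppo_eval import main

# Hypothetical path; substitute your own experiment's checkpoint.
flythrugate_ppo_config.policy.load_path = './flythrugate_ppo_seed0/ckpt/ckpt_best.pth.tar'
main(flythrugate_ppo_config)
```

The same applies to `takeoffaviary_onppo_eval.py` below.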
53 changes: 53 additions & 0 deletions dizoo/gym_pybullet_drones/entry/takeoffaviary_onppo_eval.py
@@ -0,0 +1,53 @@
import os
import gym
import torch
from tensorboardX import SummaryWriter
from easydict import EasyDict

from ding.config import compile_config
from ding.worker import BaseLearner, SampleSerialCollector, InteractionSerialEvaluator, NaiveReplayBuffer
from ding.envs import BaseEnvManager, DingEnvWrapper
from ding.policy import PPOPolicy
from ding.model import VAC
from ding.utils import set_pkg_seed

from dizoo.gym_pybullet_drones.envs.gym_pybullet_drones_env import GymPybulletDronesEnv
from dizoo.gym_pybullet_drones.config.takeoffaviary_onppo_config import takeoffaviary_ppo_config


def main(cfg, seed=0, max_iterations=int(1e10)):
    cfg = compile_config(
        cfg,
        BaseEnvManager,
        PPOPolicy,
        BaseLearner,
        SampleSerialCollector,
        InteractionSerialEvaluator,
        NaiveReplayBuffer,
        save_cfg=True
    )
    collector_env_num, evaluator_env_num = cfg.env.collector_env_num, cfg.env.evaluator_env_num

    # Enable rendering and debug output for evaluation runs.
    cfg.env['record'] = True
    cfg.env['gui'] = True
    cfg.env['print_debug_info'] = True
    cfg.env['plot_observation'] = True

    evaluator_env = BaseEnvManager(
        env_fn=[lambda: GymPybulletDronesEnv(cfg.env) for _ in range(evaluator_env_num)], cfg=cfg.env.manager
    )

    evaluator_env.seed(seed, dynamic_seed=False)
    set_pkg_seed(seed, use_cuda=cfg.policy.cuda)

    model = VAC(**cfg.policy.model)
    policy = PPOPolicy(cfg.policy, model=model)
    policy.eval_mode.load_state_dict(torch.load(cfg.policy.load_path, map_location='cpu'))

    tb_logger = SummaryWriter(os.path.join('./log/', 'serial'))
    evaluator = InteractionSerialEvaluator(cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger)
    evaluator.eval()


if __name__ == "__main__":
    main(takeoffaviary_ppo_config)
Empty file.