Merge pull request #225 from huangshiyu13/main
add gym_pybullet_drones env
huangshiyu13 authored Sep 5, 2023
2 parents 0978360 + e1e1a79 commit 220919a
Showing 10 changed files with 284 additions and 33 deletions.
66 changes: 34 additions & 32 deletions Gallery.md

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions README.md
@@ -117,6 +117,7 @@ Environments currently supported by OpenRL (for more details, please refer to [Gallery](Gallery.md)):
- [Omniverse Isaac Gym](https://github.com/NVIDIA-Omniverse/OmniIsaacGymEnvs)
- [DeepMind Control](https://shimmy.farama.org/environments/dm_control/)
- [Snake](http://www.jidiai.cn/env_detail?envid=1)
- [gym-pybullet-drones](https://github.com/utiasDSL/gym-pybullet-drones)
- [GridWorld](./examples/gridworld/)
- [Super Mario Bros](https://github.com/Kautenja/gym-super-mario-bros)
- [Gym Retro](https://github.com/openai/retro)
1 change: 1 addition & 0 deletions README_zh.md
@@ -93,6 +93,7 @@ Environments currently supported by OpenRL (for more details, please refer to [Gallery](Gallery.md)):
- [Omniverse Isaac Gym](https://github.com/NVIDIA-Omniverse/OmniIsaacGymEnvs)
- [DeepMind Control](https://shimmy.farama.org/environments/dm_control/)
- [Snake](http://www.jidiai.cn/env_detail?envid=1)
- [gym-pybullet-drones](https://github.com/utiasDSL/gym-pybullet-drones)
- [GridWorld](./examples/gridworld/)
- [Super Mario Bros](https://github.com/Kautenja/gym-super-mario-bros)
- [Gym Retro](https://github.com/openai/retro)
Binary file added docs/images/drone.png
8 changes: 8 additions & 0 deletions examples/gym_pybullet_drones/README.md
@@ -0,0 +1,8 @@

### Installation

- Python >= 3.10
- Follow the installation instructions for [gym-pybullet-drones](https://github.com/utiasDSL/gym-pybullet-drones#installation).

### Train PPO

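> Note: the Train PPO section is empty in this commit. Presumably the training entry point is the `train_ppo.py` script added in this PR (an assumption based on the file list), e.g. run `python train_ppo.py` from this directory.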
10 changes: 10 additions & 0 deletions examples/gym_pybullet_drones/ppo.yaml
@@ -0,0 +1,10 @@
episode_length: 500
lr: 1e-3
critic_lr: 1e-3
gamma: 0.1
ppo_epoch: 5
use_valuenorm: true
entropy_coef: 0.0
hidden_size: 128
layer_N: 4
use_recurrent_policy: true
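For orientation, these hyperparameters are read through OpenRL's config parser exactly as the scripts below do; a minimal sketch (the attribute access at the end is an assumption that YAML keys surface as parser attributes):

```python
from openrl.configs.config import create_config_parser

# Parse ppo.yaml; the resulting config carries the keys listed above.
cfg_parser = create_config_parser()
cfg = cfg_parser.parse_args(["--config", "ppo.yaml"])
print(cfg.lr, cfg.gamma, cfg.hidden_size)  # assumption: YAML keys become attributes
```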
70 changes: 70 additions & 0 deletions examples/gym_pybullet_drones/test_env.py
@@ -0,0 +1,70 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2023 The OpenRL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Smoke tests for gym-pybullet-drones environments, raw and via OpenRL's make()."""
import time

import gym_pybullet_drones  # noqa: F401 -- importing registers the drone envs with gymnasium
import gymnasium as gym
import numpy as np

from openrl.envs.common import make


def test_env():
    env = gym.make("hover-aviary-v0", gui=False, record=False)
    print("obs space:", env.observation_space)
    print("action space:", env.action_space)
    obs, info = env.reset(seed=42, options={})
    total_step = 0
    total_reward = 0.0
    while True:
        obs, reward, done, truncated, info = env.step(env.action_space.sample())
        total_step += 1
        total_reward += reward
        # env.render()
        # time.sleep(1)
        if done or truncated:
            break
    print("total step:", total_step)
    print("total reward:", total_reward)


def test_vec_env():
    env = make(
        "pybullet_drones/hover-aviary-v0",
        env_num=2,
        gui=False,
        record=False,
        asynchronous=True,
    )
    obs, info = env.reset(seed=0)
    total_step = 0
    total_reward = 0.0
    while True:
        obs, reward, done, info = env.step(env.random_action())
        total_step += 1
        total_reward += np.mean(reward)
        if np.any(done) or total_step > 100:
            break
    env.close()
    print("total step:", total_step)
    print("total reward:", total_reward)


if __name__ == "__main__":
    test_env()
    # test_vec_env()
89 changes: 89 additions & 0 deletions examples/gym_pybullet_drones/train_ppo.py
@@ -0,0 +1,89 @@
import numpy as np
import torch

from openrl.configs.config import create_config_parser
from openrl.envs.common import make
from openrl.modules.common import PPONet as Net
from openrl.runners.common import PPOAgent as Agent

env_name = "pybullet_drones/hover-aviary-v0"


def train():
    # create the config from ppo.yaml
    cfg_parser = create_config_parser()
    cfg = cfg_parser.parse_args(["--config", "ppo.yaml"])

    # create the environment, with parallelism set to 20
    env_num = 20
    # env_num = 1

    env = make(
        env_name,
        env_num=env_num,
        cfg=cfg,
        asynchronous=True,
        env_wrappers=[],
        gui=False,
    )

    # create the neural network
    net = Net(env, cfg=cfg, device="cuda" if torch.cuda.is_available() else "cpu")
    # initialize the trainer
    agent = Agent(net)
    # start training, with the total number of training steps set to 1,000,000
    agent.train(total_time_steps=1000000)

    agent.save("./ppo_agent")
    env.close()
    return agent


def evaluation():
    cfg_parser = create_config_parser()
    cfg = cfg_parser.parse_args(["--config", "ppo.yaml"])

    # create a single environment for evaluation, without GUI or video recording
    env = make(
        env_name,
        env_num=1,
        asynchronous=False,
        env_wrappers=[],
        cfg=cfg,
        gui=False,
        record=False,
    )

    # create the neural network
    net = Net(env, cfg=cfg, device="cuda" if torch.cuda.is_available() else "cpu")
    # initialize the agent and load the trained weights
    agent = Agent(net)
    agent.load("./ppo_agent")

    # The trained agent sets up the interactive environment it needs.
    agent.set_env(env)
    # Initialize the environment and get initial observations and environment info.
    obs, info = env.reset()
    done = False
    step = 0
    total_reward = 0.0
    while not np.any(done):
        # Based on the current observation, predict the next action.
        action, _ = agent.act(obs, deterministic=True)
        print("action:", action)
        obs, r, done, info = env.step(action)
        step += 1
        total_reward += np.mean(r)
        # if step % 50 == 0:
        #     print(f"{step}: reward: {np.mean(r)}")
    print("total step:", step)
    print("total reward:", total_reward)
    env.close()


if __name__ == "__main__":
    # train()
    evaluation()
9 changes: 8 additions & 1 deletion openrl/envs/common/registration.py
@@ -65,7 +65,14 @@ def make(
             id=id, env_num=env_num, render_mode=convert_render_mode, **kwargs
         )
     else:
-        if id.startswith("snakes_"):
+        if id.startswith("pybullet_drones/"):
+            from openrl.envs.gym_pybullet_drones import make_single_agent_drone_envs
+
+            env_fns = make_single_agent_drone_envs(
+                id=id, env_num=env_num, render_mode=convert_render_mode, **kwargs
+            )
+
+        elif id.startswith("snakes_"):
             from openrl.envs.snake import make_snake_envs
 
             env_fns = make_snake_envs(
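With this hook, any environment ID carrying the `pybullet_drones/` prefix is routed to the drone factory. A minimal usage sketch, mirroring `test_env.py` above:

```python
import numpy as np

from openrl.envs.common import make

# The "pybullet_drones/" prefix dispatches to make_single_agent_drone_envs.
env = make("pybullet_drones/hover-aviary-v0", env_num=2, asynchronous=True)
obs, info = env.reset(seed=0)
obs, reward, done, info = env.step(env.random_action())
print("mean reward:", np.mean(reward))
env.close()
```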
63 changes: 63 additions & 0 deletions openrl/envs/gym_pybullet_drones/__init__.py
@@ -0,0 +1,63 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2023 The OpenRL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Factories for building single-agent gym-pybullet-drones environments for OpenRL."""
import copy
from typing import Callable, List, Optional, Union

import gymnasium as gym
from gymnasium import Env

from openrl.envs.common import build_envs


def make_single_agent_drone_env(id: str, render_mode, disable_env_checker, **kwargs):
    import gym_pybullet_drones  # noqa: F401 -- importing registers the drone envs

    prefix = "pybullet_drones/"
    assert id.startswith(prefix), "id must start with pybullet_drones/"
    kwargs.pop("cfg", None)  # cfg is consumed by OpenRL itself, not by gym.make

    env = gym.envs.registration.make(id[len(prefix) :], **kwargs)
    return env


def make_single_agent_drone_envs(
    id: str,
    env_num: int = 1,
    render_mode: Optional[Union[str, List[str]]] = None,
    **kwargs,
) -> List[Callable[[], Env]]:
    from openrl.envs.wrappers import (  # AutoReset,; DictWrapper,
        RemoveTruncated,
        Single2MultiAgentWrapper,
    )

    env_wrappers = copy.copy(kwargs.pop("env_wrappers", []))
    env_wrappers += [
        Single2MultiAgentWrapper,
        RemoveTruncated,
    ]

    env_fns = build_envs(
        make=make_single_agent_drone_env,
        id=id,
        env_num=env_num,
        render_mode=render_mode,
        wrappers=env_wrappers,
        **kwargs,
    )
    return env_fns
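A quick smoke test of the factory, as a hedged sketch (the ID, `gui`, and `record` arguments mirror `test_env.py`; that `build_envs` returns a lazily evaluated list of factories, one per environment, is an assumption from its usage here):

```python
from openrl.envs.gym_pybullet_drones import make_single_agent_drone_envs

# Build two environment factories; make() in registration.py consumes such lists.
env_fns = make_single_agent_drone_envs(
    id="pybullet_drones/hover-aviary-v0", env_num=2, gui=False, record=False
)
print(len(env_fns))  # expected: 2
```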
