Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add mujoco #88

Merged
merged 4 commits into from
May 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 11 additions & 8 deletions Gallery.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,14 @@ Users are also welcome to contribute their own training examples and demos to th

<div align="center">

| Algorithm | Tags | Refs |
|:-----------------------------------------:|:-----------------------------------------------------------------------------------------------------------------------:|:-------------------------------:|
| [PPO](https://arxiv.org/abs/1707.06347) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | [code](./examples/cartpole/) |
| [MAPPO](https://arxiv.org/abs/2103.01955) | ![MARL](https://img.shields.io/badge/-MARL-yellow) | [code](./examples/mpe/) |
| [JRPO](https://arxiv.org/abs/2302.07515) | ![MARL](https://img.shields.io/badge/-MARL-yellow) | [code](./examples/mpe/) |
| [MAT](https://arxiv.org/abs/2205.14953) | ![MARL](https://img.shields.io/badge/-MARL-yellow) | [code](./examples/mpe/) |
| Algorithm | Tags | Refs |
|:-------------------------------------------------:|:-----------------------------------------------------------------------------------------------------------------------:|:-------------------------------:|
| [PPO](https://arxiv.org/abs/1707.06347) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | [code](./examples/cartpole/) |
| [PPO-continuous](https://arxiv.org/abs/1707.06347) | ![continuous](https://img.shields.io/badge/-continous-green) | [code](./examples/mujoco/) |
| [Dual-clip PPO](https://arxiv.org/abs/1912.09729) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | [code](./examples/cartpole/) |
| [MAPPO](https://arxiv.org/abs/2103.01955) | ![MARL](https://img.shields.io/badge/-MARL-yellow) | [code](./examples/mpe/) |
| [JRPO](https://arxiv.org/abs/2302.07515) | ![MARL](https://img.shields.io/badge/-MARL-yellow) | [code](./examples/mpe/) |
| [MAT](https://arxiv.org/abs/2205.14953) | ![MARL](https://img.shields.io/badge/-MARL-yellow) | [code](./examples/mpe/) |
</div>

## Demo List
Expand All @@ -38,8 +40,9 @@ Users are also welcome to contribute their own training examples and demos to th

| Environment/Demo | Tags | Refs |
|:---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------:|:-----------------------------------------------------------------------------------------------------------------------:|:-------------------------------:|
| [CartPole](https://gymnasium.farama.org/environments/classic_control/cart_pole/)<br> <img width="300px" height="auto" src="./docs/images/cartpole_trained.gif"> | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | [code](./examples/cartpole/) |
| [MuJoCo](https://github.com/deepmind/mujoco)<br> <img width="300px" height="auto" src="./docs/images/mujoco.png"> | ![continuous](https://img.shields.io/badge/-continous-green) | [code](./examples/mujoco/) |
| [CartPole](https://gymnasium.farama.org/environments/classic_control/cart_pole/)<br> <img width="300px" height="auto" src="./docs/images/cartpole.png"> | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | [code](./examples/cartpole/) |
| [MPE: Simple Spread](https://pettingzoo.farama.org/environments/mpe/simple_spread/)<br> <img width="300px" height="auto" src="./docs/images/simple_spread_trained.gif"> | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) ![MARL](https://img.shields.io/badge/-MARL-yellow) | [code](./examples/mpe/) |
| [Super Mario Bros](https://github.com/Kautenja/gym-super-mario-bros)<br> <img width="300px" height="auto" src="https://user-images.githubusercontent.com/2184469/40948820-3d15e5c2-6830-11e8-81d4-ecfaffee0a14.png"> | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | [code](./examples/super_mario/) |
| [Gym Retro](https://github.com/openai/retro)<br> <img width="300px" height="auto" src="./docs/images/gym-retro.webp"> | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | [code](./examples/retro/) |
| [Gym Retro](https://github.com/openai/retro)<br> <img width="300px" height="auto" src="./docs/images/gym-retro.webp"> | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | [code](./examples/retro/) |
</div>
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,12 +71,14 @@ Currently, the features supported by OpenRL include:

Algorithms currently supported by OpenRL (for more details, please refer to [Gallery](./Gallery.md)):
- [Proximal Policy Optimization (PPO)](https://arxiv.org/abs/1707.06347)
- [Dual-clip PPO](https://arxiv.org/abs/1912.09729)
- [Multi-agent PPO (MAPPO)](https://arxiv.org/abs/2103.01955)
- [Joint-ratio Policy Optimization (JRPO)](https://arxiv.org/abs/2302.07515)
- [Multi-Agent Transformer (MAT)](https://arxiv.org/abs/2205.14953)

Environments currently supported by OpenRL (for more details, please refer to [Gallery](./Gallery.md)):
- [Gymnasium](https://gymnasium.farama.org/)
- [MuJoCo](https://github.com/deepmind/mujoco)
- [MPE](https://github.com/openai/multiagent-particle-envs)
- [Super Mario Bros](https://github.com/Kautenja/gym-super-mario-bros)
- [Gym Retro](https://github.com/openai/retro)
Expand Down
2 changes: 2 additions & 0 deletions README_zh.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,14 @@ OpenRL是一个开源的通用强化学习研究框架,支持单智能体、

OpenRL目前支持的算法(更多详情请参考 [Gallery](Gallery.md)):
- [Proximal Policy Optimization (PPO)](https://arxiv.org/abs/1707.06347)
- [Dual-clip PPO](https://arxiv.org/abs/1912.09729)
- [Multi-agent PPO (MAPPO)](https://arxiv.org/abs/2103.01955)
- [Joint-ratio Policy Optimization (JRPO)](https://arxiv.org/abs/2302.07515)
- [Multi-Agent Transformer (MAT)](https://arxiv.org/abs/2205.14953)

OpenRL目前支持的环境(更多详情请参考 [Gallery](Gallery.md)):
- [Gymnasium](https://gymnasium.farama.org/)
- [MuJoCo](https://github.com/deepmind/mujoco)
- [MPE](https://github.com/openai/multiagent-particle-envs)
- [Super Mario Bros](https://github.com/Kautenja/gym-super-mario-bros)
- [Gym Retro](https://github.com/openai/retro)
Expand Down
Binary file added docs/images/cartpole.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file removed docs/images/cartpole_trained.gif
Binary file not shown.
Binary file added docs/images/mujoco.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
7 changes: 7 additions & 0 deletions examples/cartpole/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,11 @@ Users can train CartPole via:

```shell
python train_ppo.py
```


To train with [Dual-clip PPO](https://arxiv.org/abs/1912.09729):

```shell
python train_ppo.py --config dual_clip_ppo.yaml
```
2 changes: 2 additions & 0 deletions examples/cartpole/dual_clip_ppo.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
dual_clip_ppo: true
dual_clip_coeff: 3.0
11 changes: 9 additions & 2 deletions examples/cartpole/train_ppo.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
""""""
import numpy as np

from openrl.configs.config import create_config_parser
from openrl.envs.common import make
from openrl.modules.common import PPONet as Net
from openrl.runners.common import PPOAgent as Agent
Expand All @@ -10,7 +11,12 @@ def train():
# create environment, set environment parallelism to 9
env = make("CartPole-v1", env_num=9)
# create the neural network
net = Net(env)
cfg_parser = create_config_parser()
cfg = cfg_parser.parse_args()
net = Net(
env,
cfg=cfg,
)
# initialize the trainer
agent = Agent(net)
# start training, set total number of training steps to 20000
Expand All @@ -34,7 +40,8 @@ def evaluation(agent):
action, _ = agent.act(obs, deterministic=True)
obs, r, done, info = env.step(action)
step += 1
print(f"{step}: reward:{np.mean(r)}")
if step % 50 == 0:
print(f"{step}: reward:{np.mean(r)}")
env.close()


Expand Down
9 changes: 9 additions & 0 deletions examples/mujoco/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
## Installation

`pip install mujoco`

## Usage

```shell
python train_ppo.py
```
65 changes: 65 additions & 0 deletions examples/mujoco/train_ppo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2023 The OpenRL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

""""""
import numpy as np

from openrl.envs.common import make
from openrl.modules.common import PPONet as Net
from openrl.runners.common import PPOAgent as Agent


def train():
# create environment, set environment parallelism to 9
env = make("InvertedPendulum-v4", env_num=9)
# create the neural network
net = Net(env)
# initialize the trainer
agent = Agent(net)
# start training, set total number of training steps to 20000
agent.train(total_time_steps=30000)
env.close()
return agent


def evaluation(agent):
# begin to test
# Create an environment for testing and set the number of environments to interact with to 9. Set rendering mode to group_human.
env = make("InvertedPendulum-v4", render_mode=None, env_num=9, asynchronous=False)

# The trained agent sets up the interactive environment it needs.
agent.set_env(env)
# Initialize the environment and get initial observations and environmental information.
obs, info = env.reset()

done = False
step = 0
totoal_reward = 0
while not np.any(done):
# Based on environmental observation input, predict next action.
action, _ = agent.act(obs, deterministic=True)
obs, r, done, info = env.step(action)
step += 1
if step % 100 == 0:
print(f"{step}: reward:{np.mean(r)}")
totoal_reward += np.mean(r)
env.close()
print(f"total reward: {totoal_reward}")


if __name__ == "__main__":
agent = train()
evaluation(agent)
1 change: 0 additions & 1 deletion openrl/configs/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,7 +460,6 @@ def create_config_parser():
)
parser.add_argument(
"--dual_clip_ppo",
action="store_true",
default=False,
help="by default False, use dual-clip ppo.",
)
Expand Down
4 changes: 3 additions & 1 deletion openrl/envs/vec_env/base_venv.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import gymnasium as gym
import numpy as np

from openrl.envs.vec_env.utils.numpy_utils import single_random_action
from openrl.envs.vec_env.utils.util import tile_images

IN_COLAB = "google.colab" in sys.modules
Expand Down Expand Up @@ -257,9 +258,10 @@ def random_action(self):
"""
Get a random action from the action space
"""

return np.array(
[
[[self.action_space.sample()] for _ in range(self.agent_num)]
[single_random_action(self.action_space) for _ in range(self.agent_num)]
for _ in range(self.parallel_env_num)
]
)
20 changes: 19 additions & 1 deletion openrl/envs/vec_env/utils/numpy_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
"concatenate",
"create_empty_array",
"iterate_action",
"single_random_action",
]


Expand Down Expand Up @@ -53,7 +54,7 @@ def _iterate_discrete(space, actions):
@iterate_action.register(MultiDiscrete)
@iterate_action.register(MultiBinary)
def _iterate_base(space, actions):
raise NotImplementedError("Not implemented yet.")
return iter(actions)


@iterate_action.register(Tuple)
Expand Down Expand Up @@ -205,3 +206,20 @@ def _create_empty_array_dict(space, n=1, agent_num=1, fn=np.zeros):
@create_empty_array.register(Space)
def _create_empty_array_custom(space, n=1, agent_num=1, fn=np.zeros):
return None


@singledispatch
def single_random_action(space: Space) -> Union[tuple, dict, np.ndarray]:
raise ValueError(
f"Space of type `{type(space)}` is not a valid `gymnasium.Space` instance."
)


@single_random_action.register(Discrete)
def _single_random_action_discrete(space):
return [space.sample()]


@single_random_action.register(Box)
def _single_random_action_discrete(space):
return space.sample()
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def get_install_requires() -> list:
"imageio",
"opencv-python",
"pygame",
"mujoco",
]


Expand Down
58 changes: 58 additions & 0 deletions tests/test_examples/test_train_mujoco.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2023 The OpenRL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

""""""

import os
import sys

import numpy as np
import pytest

from openrl.envs.common import make
from openrl.modules.common import PPONet as Net
from openrl.runners.common import PPOAgent as Agent


@pytest.fixture(scope="module", params=[""])
def config(request):
from openrl.configs.config import create_config_parser

cfg_parser = create_config_parser()
cfg = cfg_parser.parse_args(request.param.split())
return cfg


@pytest.mark.unittest
def test_train_mujoco(config):
env = make("InvertedPendulum-v4", env_num=9)
agent = Agent(Net(env, cfg=config))
agent.train(total_time_steps=30000)

agent.set_env(env)
obs, info = env.reset()
done = False
total_reward = 0
while not np.any(done):
action, _ = agent.act(obs, deterministic=True)
obs, r, done, info = env.step(action)
total_reward += np.mean(r)
assert total_reward >= 900, "InvertedPendulum-v4 should be solved."
env.close()


if __name__ == "__main__":
sys.exit(pytest.main(["-sv", os.path.basename(__file__)]))