Merge pull request #202 from huangshiyu13/main
update
huangshiyu13 authored Aug 12, 2023
2 parents 9623d5c + 1f8c3ef commit 6300cc6
Showing 10 changed files with 45 additions and 69 deletions.
1 change: 1 addition & 0 deletions Gallery.md
@@ -63,6 +63,7 @@ Users are also welcome to contribute their own training examples and demos to th
| [Chat Bot](https://openrl-docs.readthedocs.io/en/latest/quick_start/train_nlp.html)<br> <img width="300px" height="auto" src="./docs/images/chat.gif"> | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) ![NLP](https://img.shields.io/badge/-NLP-green) ![Transformer](https://img.shields.io/badge/-Transformer-blue) | [code](./examples/nlp/) |
| [Atari Pong](https://gymnasium.farama.org/environments/atari/pong/)<br> <img width="300px" height="auto" src="./docs/images/pong.png"> | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) ![image](https://img.shields.io/badge/-image-red) | [code](./examples/atari/) |
| [PettingZoo: Tic-Tac-Toe](https://pettingzoo.farama.org/environments/classic/tictactoe/)<br> <img width="300px" height="auto" src="./docs/images/tic-tac-toe.jpeg"> | ![selfplay](https://img.shields.io/badge/-selfplay-blue) ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | [code](./examples/selfplay/) |
| [DeepMind Control](https://shimmy.farama.org/environments/dm_control/)<br> <img width="300px" height="auto" src="https://shimmy.farama.org/_images/dm_locomotion.png"> | ![continuous](https://img.shields.io/badge/-continous-green) | [code](./examples/dm_control/) |
| [Omniverse Isaac Gym](https://github.com/NVIDIA-Omniverse/OmniIsaacGymEnvs)<br> <img width="300px" height="auto" src="https://user-images.githubusercontent.com/34286328/171454189-6afafbff-bb61-4aac-b518-24646007cb9f.gif"> | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | [code](./examples/isaac/) |
| [GridWorld](./examples/gridworld/)<br> <img width="300px" height="auto" src="./docs/images/gridworld.jpg"> | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | [code](./examples/gridworld/) |
| [Super Mario Bros](https://github.com/Kautenja/gym-super-mario-bros)<br> <img width="300px" height="auto" src="https://user-images.githubusercontent.com/2184469/40948820-3d15e5c2-6830-11e8-81d4-ecfaffee0a14.png"> | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) ![image](https://img.shields.io/badge/-image-red) | [code](./examples/super_mario/) |
1 change: 1 addition & 0 deletions README.md
@@ -104,6 +104,7 @@ Environments currently supported by OpenRL (for more details, please refer to [G
- [Atari](https://gymnasium.farama.org/environments/atari/)
- [StarCraft II](https://github.com/oxwhirl/smac)
- [Omniverse Isaac Gym](https://github.com/NVIDIA-Omniverse/OmniIsaacGymEnvs)
- [DeepMind Control](https://shimmy.farama.org/environments/dm_control/)
- [GridWorld](./examples/gridworld/)
- [Super Mario Bros](https://github.com/Kautenja/gym-super-mario-bros)
- [Gym Retro](https://github.com/openai/retro)
1 change: 1 addition & 0 deletions README_zh.md
@@ -86,6 +86,7 @@ OpenRL目前支持的环境(更多详情请参考 [Gallery](Gallery.md)):
- [Atari](https://gymnasium.farama.org/environments/atari/)
- [StarCraft II](https://github.com/oxwhirl/smac)
- [Omniverse Isaac Gym](https://github.com/NVIDIA-Omniverse/OmniIsaacGymEnvs)
- [DeepMind Control](https://shimmy.farama.org/environments/dm_control/)
- [GridWorld](./examples/gridworld/)
- [Super Mario Bros](https://github.com/Kautenja/gym-super-mario-bros)
- [Gym Retro](https://github.com/openai/retro)
5 changes: 2 additions & 3 deletions examples/behavior_cloning/test_env.py
@@ -10,9 +10,8 @@ def test_env():
    cfg_parser = create_config_parser()
    cfg = cfg_parser.parse_args()

    # create environment, set environment parallelism to 9
    # env = make("OfflineEnv", env_num=1, cfg=cfg, asynchronous=True)
    env = make("OfflineEnv", env_num=1, cfg=cfg, asynchronous=False)
    # create environment
    env = make("OfflineEnv", env_num=1, cfg=cfg, asynchronous=True)

    for ep_index in range(10):
        done = False
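For readers skimming this change: a hedged sketch of how such an episode loop typically continues with OpenRL's vectorized environment API. The `random_action()` helper and the 4-tuple returned by `step()` are assumptions based on other OpenRL example scripts, not part of this diff.

```python
# Hedged sketch only -- not part of the commit. It assumes OpenRL's make()
# from openrl.envs.common and a random_action() helper on the returned env,
# as seen in other OpenRL example scripts.
import numpy as np

from openrl.configs.config import create_config_parser
from openrl.envs.common import make


def rollout_sketch():
    cfg_parser = create_config_parser()
    cfg = cfg_parser.parse_args()

    # create an asynchronous vectorized environment, as in the updated test
    env = make("OfflineEnv", env_num=1, cfg=cfg, asynchronous=True)

    for ep_index in range(10):
        done = False
        obs, info = env.reset()
        while not np.any(done):
            # vectorized envs are assumed to return (obs, reward, done, info)
            obs, reward, done, info = env.step(env.random_action())
    env.close()
```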
9 changes: 9 additions & 0 deletions examples/dm_control/README.md
@@ -0,0 +1,9 @@
## Installation
```bash
pip install shimmy[dm-control]
```

## Usage
```bash
python train_ppo.py
```
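
For orientation, here is a hedged sketch of what `train_ppo.py` (shown in full below) boils down to. The `FlattenObservation` import path and the exact `make()` arguments used during training (partly elided in the diff, so mirrored from the evaluation call) are assumptions; the rest follows the diff.

```python
# Hedged sketch -- condensed from the train_ppo.py diff below; the
# FlattenObservation import and some make() arguments are assumptions.
from gymnasium.wrappers import FlattenObservation

from openrl.configs.config import create_config_parser
from openrl.envs.common import make
from openrl.envs.wrappers.extra_wrappers import FrameSkip
from openrl.modules.common import PPONet as Net
from openrl.runners.common import PPOAgent as Agent

cfg_parser = create_config_parser()
cfg = cfg_parser.parse_args(["--config", "ppo.yaml"])

# dm_control tasks are exposed through Shimmy under the dm_control/ namespace
env = make(
    "dm_control/cartpole-balance-v0",
    env_num=64,
    asynchronous=True,
    env_wrappers=[FrameSkip, FlattenObservation],
    cfg=cfg,
)

agent = Agent(Net(env, cfg=cfg, device="cuda"))
agent.train(total_time_steps=100000)
agent.save("./ppo_agent")
env.close()
```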
File renamed without changes.
45 changes: 12 additions & 33 deletions examples/dmc/train_ppo.py → examples/dm_control/train_ppo.py
@@ -7,25 +7,7 @@
from openrl.envs.wrappers.extra_wrappers import GIFWrapper
from openrl.modules.common import PPONet as Net
from openrl.runners.common import PPOAgent as Agent


class FrameSkip(BaseWrapper):
    def __init__(self, env, num_frames: int = 8):
        super().__init__(env)
        self.num_frames = num_frames

    def step(self, action):
        num_skips = self.num_frames
        total_reward = 0.0

        for x in range(num_skips):
            obs, rew, term, trunc, info = super().step(action)
            total_reward += rew
            if term or trunc:
                break

        return obs, total_reward, term, trunc, info

from openrl.envs.wrappers.extra_wrappers import FrameSkip

env_name = "dm_control/cartpole-balance-v0"
# env_name = "dm_control/walker-walk-v0"
@@ -36,7 +18,7 @@ def train():
    cfg_parser = create_config_parser()
    cfg = cfg_parser.parse_args(["--config", "ppo.yaml"])

    # create environment, set environment parallelism to 9
    # create environment, set environment parallelism to 64
    env = make(
        env_name,
        env_num=64,
@@ -50,35 +32,30 @@
    agent = Agent(
        net,
    )
    # start training, set total number of training steps to 20000
    # start training, set total number of training steps to 100000
    agent.train(total_time_steps=100000)
    agent.save("./ppo_agent")
    env.close()
    return agent





def evaluation():
    cfg_parser = create_config_parser()
    cfg = cfg_parser.parse_args(["--config", "ppo.yaml"])
    # begin to test
    # Create an environment for testing and set the number of environments to interact with to 9. Set rendering mode to group_human.
    render_mode = "group_human"
    # Create an environment for testing and set the number of environments to interact with to 4. Set rendering mode to group_rgb_array.
    render_mode = "group_rgb_array"
    env = make(
        env_name,
        render_mode=render_mode,
        env_num=4,
        asynchronous=True,
        env_wrappers=[FrameSkip,FlattenObservation],
        cfg=cfg
        env_wrappers=[FrameSkip, FlattenObservation],
        cfg=cfg,
    )
    # Wrap the environment with GIFWrapper to record the GIF, and set the frame rate to 5.
    env = GIFWrapper(env, gif_path="./new.gif", fps=5)



    net = Net(env, cfg=cfg, device="cuda")
    # initialize the trainer
    agent = Agent(
@@ -103,8 +80,10 @@ def evaluation():
        total_reward += np.mean(r)
        if step % 50 == 0:
            print(f"{step}: reward:{np.mean(r)}")
    print("total step:", step, total_reward)
    print("total step:", step, "total reward:", total_reward)
    env.close()

train()
evaluation()

if __name__ == "__main__":
    train()
    evaluation()
2 changes: 1 addition & 1 deletion openrl/envs/dmc/__init__.py
@@ -13,7 +13,7 @@ def make_dmc_envs(
    render_mode: Optional[Union[str, List[str]]] = None,
    **kwargs,
):
    from openrl.envs.wrappers import ( # AutoReset,; DictWrapper,
    from openrl.envs.wrappers import (
        RemoveTruncated,
        Single2MultiAgentWrapper,
    )
32 changes: 0 additions & 32 deletions openrl/envs/dmc/dmc_env.py
@@ -1,45 +1,13 @@
from typing import Any, Optional

import dmc2gym
import gymnasium as gym
import numpy as np

# class DmcEnv:
#     def __init__(self):
#         env = dmc2gym.make(
#             domain_name='walker',
#             task_name='walk',
#             seed=42,
#             visualize_reward=False,
#             from_pixels='features',
#             height=224,
#             width=224,
#             frame_skip=2
#         )
#         # self.observation_space = spaces.Box(
#         #     low=np.array([0, 0, 0, 0]),
#         #     high=np.array([self.nrow - 1, self.ncol - 1, self.nrow - 1, self.ncol - 1]),
#         #     dtype=int,
#         # ) # current position and target position
#         # self.action_space = spaces.Discrete(
#         #     5
#         # )


def make(
    id: str,
    render_mode: Optional[str] = None,
    **kwargs: Any,
):
    env = gym.make(id, render_mode=render_mode)
    # env = dmc2gym.make(
    #     domain_name='walker',
    #     task_name='walk',
    #     seed=42,
    #     visualize_reward=False,
    #     from_pixels='features',
    #     height=224,
    #     width=224,
    #     frame_skip=2
    # )
    return env
18 changes: 18 additions & 0 deletions openrl/envs/wrappers/extra_wrappers.py
@@ -28,6 +28,24 @@
from openrl.envs.wrappers.flatten import flatten


class FrameSkip(BaseWrapper):
    def __init__(self, env, num_frames: int = 8):
        super().__init__(env)
        self.num_frames = num_frames

    def step(self, action):
        num_skips = self.num_frames
        total_reward = 0.0

        for x in range(num_skips):
            obs, rew, term, trunc, info = super().step(action)
            total_reward += rew
            if term or trunc:
                break

        return obs, total_reward, term, trunc, info


class RemoveTruncated(StepAPICompatibility, BaseWrapper):
    def __init__(
        self,
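A hedged usage sketch for the relocated wrapper: it assumes OpenRL's `BaseWrapper` follows the standard Gymnasium single-env interface, which the 5-tuple returned by `super().step()` above suggests, but this is not confirmed by the diff.

```python
# Hedged sketch -- not part of the commit. Assumes FrameSkip can wrap a plain
# Gymnasium env, which the (obs, rew, term, trunc, info) return above suggests.
import gymnasium as gym

from openrl.envs.wrappers.extra_wrappers import FrameSkip

env = FrameSkip(gym.make("CartPole-v1"), num_frames=4)

obs, info = env.reset()
# One step() call now repeats the sampled action for up to 4 frames, summing
# the rewards and stopping early if the episode terminates or truncates.
obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
env.close()
```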
