Add envpool to openrl
kingjuno committed Dec 7, 2023
1 parent 9a05e6f commit dac2804
Showing 7 changed files with 425 additions and 8 deletions.
78 changes: 78 additions & 0 deletions examples/envpool/test_model.py
@@ -0,0 +1,78 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2023 The OpenRL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

""""""

# Use OpenRL to load a stable-baselines3 model for testing

import numpy as np
import torch

from openrl.configs.config import create_config_parser
from openrl.envs.common import make
from openrl.modules.common.ppo_net import PPONet as Net
from openrl.modules.networks.policy_value_network_sb3 import (
PolicyValueNetworkSB3 as PolicyValueNetwork,
)
from openrl.runners.common import PPOAgent as Agent


def evaluation(local_trained_file_path=None):
# begin to test

cfg_parser = create_config_parser()
cfg = cfg_parser.parse_args(["--config", "ppo.yaml"])

    # Create an environment for testing and set the number of parallel
    # environments to 9. Rendering is disabled here; set
    # render_mode = "group_human" to watch the evaluation.
    render_mode = None
env = make("CartPole-v1", render_mode=render_mode, env_num=9, asynchronous=True)
model_dict = {"model": PolicyValueNetwork}
net = Net(
env,
cfg=cfg,
model_dict=model_dict,
device="cuda" if torch.cuda.is_available() else "cpu",
)
    # initialize the agent
    agent = Agent(net)
if local_trained_file_path is not None:
agent.load(local_trained_file_path)
    # Give the trained agent the environment it will interact with.
    agent.set_env(env)
    # Reset the environment and get the initial observations and info.
obs, info = env.reset()
done = False

total_step = 0
total_reward = 0.0
while not np.any(done):
        # Predict the next action from the current observation.
action, _ = agent.act(obs, deterministic=True)
obs, r, done, info = env.step(action)
total_step += 1
total_reward += np.mean(r)
if total_step % 50 == 0:
print(f"{total_step}: reward:{np.mean(r)}")
env.close()
print("total step:", total_step)
print("total reward:", total_reward)


if __name__ == "__main__":
evaluation()
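
A hedged usage note: evaluation() also accepts a path to a locally saved agent, which it passes to agent.load(). The path below is illustrative, not part of this commit.

# Load a previously saved checkpoint before evaluating (path is an example)
evaluation(local_trained_file_path="./ppo_agent/")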
86 changes: 86 additions & 0 deletions examples/envpool/train_ppo.py
@@ -0,0 +1,86 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2023 The OpenRL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

""""""
import numpy as np

from openrl.configs.config import create_config_parser
from openrl.envs.common import make
from openrl.envs.wrappers.envpool_wrappers import VecAdapter, VecMonitor
from openrl.modules.common import PPONet as Net
from openrl.runners.common import PPOAgent as Agent


def train():
# create the neural network
cfg_parser = create_config_parser()
cfg = cfg_parser.parse_args()

# create environment, set environment parallelism to 9
env = make(
"envpool:Adventure-v5",
render_mode=None,
env_num=9,
asynchronous=False,
env_wrappers=[VecAdapter, VecMonitor],
env_type="gym",
)

net = Net(
env,
cfg=cfg,
)
    # initialize the agent
    agent = Agent(net, use_wandb=False, project_name="envpool:Adventure-v5")
    # start training; set the total number of training steps to 20,000
    agent.train(total_time_steps=20000)

env.close()
return agent


def evaluation(agent):
# begin to test
    # Create an evaluation environment that matches the training environment,
    # with 9 parallel environments. envpool does not support rendering yet.
    env = make(
        "envpool:Adventure-v5",
        render_mode=None,
        env_num=9,
        asynchronous=False,
        env_wrappers=[VecAdapter, VecMonitor],
        env_type="gym",
    )
    # Give the trained agent the environment it will interact with.
    agent.set_env(env)
    # Reset the environment and get the initial observations and info.
obs, info = env.reset()
done = False
    total_step, total_reward = 0, 0.0
while not np.any(done):
        # Predict the next action from the current observation.
action, _ = agent.act(obs, deterministic=True)
obs, r, done, info = env.step(action)
        total_step += 1
        total_reward += np.mean(r)
        if total_step % 50 == 0:
            print(f"{total_step}: reward:{np.mean(r)}")
env.close()
print("total step:", total_step)
print("total reward:", total_reward)


if __name__ == "__main__":
agent = train()
evaluation(agent)
24 changes: 17 additions & 7 deletions openrl/envs/common/build_envs.py
@@ -6,6 +6,7 @@
from gymnasium import Env

from openrl.envs.wrappers.base_wrapper import BaseWrapper
from openrl.envs.wrappers.envpool_wrappers import VecEnvWrapper, VecMonitor


def build_envs(
@@ -36,13 +37,22 @@ def _make_env() -> Env:
new_kwargs["env_num"] = env_num
if id.startswith("ALE/") or id in gym.envs.registry.keys():
new_kwargs.pop("cfg", None)

env = make(
id,
render_mode=env_render_mode,
disable_env_checker=_disable_env_checker,
**new_kwargs,
)
if "envpool" in new_kwargs:
            # for now envpool doesn't support any render mode;
            # envpool also doesn't store the id anywhere
new_kwargs.pop("envpool")
env = make(
id,
**new_kwargs,
)
env.unwrapped.spec.id = id
else:
env = make(
id,
render_mode=env_render_mode,
disable_env_checker=_disable_env_checker,
**new_kwargs,
)

if wrappers is not None:
if callable(wrappers):
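For context, the new envpool branch above works around two envpool limitations: envpool ignores render_mode, and the env it returns does not keep the requested id. A minimal standalone sketch of the same workaround, assuming envpool is installed (the task id is only an example):

import envpool

# envpool.make takes no render_mode, so none is passed
env = envpool.make("Adventure-v5", env_type="gym", num_envs=9)
# envpool does not store the requested id; write it back so downstream
# code that reads env.unwrapped.spec.id keeps working
env.unwrapped.spec.id = "Adventure-v5"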
14 changes: 13 additions & 1 deletion openrl/envs/common/registration.py
@@ -17,6 +17,7 @@
""""""
from typing import Callable, Optional

import envpool
import gymnasium as gym

import openrl
@@ -72,7 +73,6 @@ def make(
env_fns = make_single_agent_drone_envs(
id=id, env_num=env_num, render_mode=convert_render_mode, **kwargs
)

elif id.startswith("snakes_"):
from openrl.envs.snake import make_snake_envs

@@ -155,6 +155,18 @@ def make(
env_fns = make_PettingZoo_envs(
id=id, env_num=env_num, render_mode=convert_render_mode, **kwargs
)
elif (
"envpool:" in id
and id.split(":")[-1] in envpool.registration.list_all_envs()
):
from openrl.envs.envpool import make_envpool_envs

env_fns = make_envpool_envs(
id=id.split(":")[-1],
env_num=env_num,
render_mode=convert_render_mode,
**kwargs,
)
else:
raise NotImplementedError(f"env {id} is not supported.")

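With this branch in place, any registered envpool task can be created through OpenRL's make() by prefixing its id with "envpool:". A minimal sketch of that convention, mirroring the example scripts in this commit (the task id is illustrative; valid ids come from envpool.registration.list_all_envs()):

from openrl.envs.common import make
from openrl.envs.wrappers.envpool_wrappers import VecAdapter, VecMonitor

# The "envpool:" prefix routes the id to make_envpool_envs
env = make(
    "envpool:Pong-v5",
    env_num=4,
    asynchronous=False,
    env_wrappers=[VecAdapter, VecMonitor],
    env_type="gym",
)
obs, info = env.reset()
env.close()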
47 changes: 47 additions & 0 deletions openrl/envs/envpool/__init__.py
@@ -0,0 +1,47 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2023 The OpenRL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

""""""
from typing import List, Optional, Union

import envpool

from openrl.envs.common import build_envs


def make_envpool_envs(
id: str,
env_num: int = 1,
render_mode: Optional[Union[str, List[str]]] = None,
**kwargs,
):
assert "env_type" in kwargs
assert kwargs.get("env_type") in ["gym", "dm", "gymnasium"]
    # envpool does not support render_mode yet, so we require it to be None
    # and set the "envpool" flag so that build_envs can drop the
    # render_mode keyword argument before calling envpool.make
    assert render_mode is None, "envpool does not support render_mode yet"
kwargs["envpool"] = True

    env_wrappers = kwargs.pop("env_wrappers", [])  # default to no wrappers
env_fns = build_envs(
make=envpool.make,
id=id,
env_num=env_num,
render_mode=render_mode,
wrappers=env_wrappers,
**kwargs,
)
return env_fns
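A hedged sketch of calling the new factory directly; in practice make() in openrl/envs/common/registration.py invokes it for ids prefixed with "envpool:" (the task id below is only an example):

from openrl.envs.envpool import make_envpool_envs

# Returns a list of environment constructors built around envpool.make;
# env_type is mandatory and must be "gym", "dm", or "gymnasium"
env_fns = make_envpool_envs(
    id="Pong-v5",
    env_num=2,
    env_wrappers=[],
    env_type="gym",
)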
