Skip to content

Commit

Permalink
Merge pull request #67 from ChildTang/main
Browse files Browse the repository at this point in the history
Add Retro Environment
  • Loading branch information
huangshiyu13 authored May 16, 2023
2 parents fe4b32a + 9efabfc commit dff448d
Show file tree
Hide file tree
Showing 13 changed files with 1,228 additions and 9 deletions.
51 changes: 51 additions & 0 deletions examples/retro/retro_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
""""""
import numpy as np

from openrl.envs.common import make
from openrl.modules.common import PPONet as Net
from openrl.runners.common import PPOAgent as Agent


def train():
    """Train a PPO agent on the Airstriker-Genesis retro environment.

    Returns:
        The trained ``Agent`` instance, ready for evaluation.
    """
    # Create the environment. asynchronous=True runs the env_num copies in
    # parallel processes; the optional `state` argument selects the in-game
    # level and is game-specific.
    env = make("Airstriker-Genesis", state="Level1", env_num=2, asynchronous=True)
    # Build the PPO network.
    # NOTE(review): device="cuda" is hard-coded — fails on CPU-only hosts.
    net = Net(env, device="cuda")
    # Initialize the trainer around the network.
    agent = Agent(net)
    # Run training.
    agent.train(total_time_steps=20000)
    # Release environment resources.
    env.close()
    return agent


def game_test(agent):
    """Visualize a trained agent in a rendered retro test environment.

    Args:
        agent: A trained Agent; its environment is replaced by the
            freshly created rendering environment below.

    NOTE(review): the render loop never breaks, so ``env.close()`` after it
    is unreachable — the demo runs until interrupted. Confirm this is the
    intended behavior for the group_human render mode.
    """
    # Build the evaluation environment with on-screen rendering.
    env = make(
        "Airstriker-Genesis",
        state="Level1",
        render_mode="group_human",
        env_num=4,
        asynchronous=True,
    )
    agent.set_env(env)
    obs, info = env.reset()
    step = 0
    while True:
        # The agent predicts the next action from the current observation.
        action, _ = agent.act(obs, deterministic=True)
        obs, r, done, info = env.step(action)
        step += 1
        # `r` is a per-environment reward array; report its mean.
        print(f"{step}: reward:{np.mean(r)}")

        # Restart finished episodes; `done` holds one flag per environment.
        if any(done):
            env.reset()

    env.close()


if __name__ == "__main__":
    # Train first, then visualize the trained agent.
    trained_agent = train()
    game_test(trained_agent)
5 changes: 4 additions & 1 deletion openrl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,7 @@
3,
8,
0,
], f"OpenRL requires Python 3.8 or newer, but your Python is {platform.python_version()}"
], (
"OpenRL requires Python 3.8 or newer, but your Python is"
f" {platform.python_version()}"
)
4 changes: 3 additions & 1 deletion openrl/drivers/onpolicy_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@ def __init__(
client=None,
logger: Optional[Logger] = None,
) -> None:
super(OnPolicyDriver, self).__init__(config, trainer, buffer, rank, world_size, client, logger)
super(OnPolicyDriver, self).__init__(
config, trainer, buffer, rank, world_size, client, logger
)

def _inner_loop(
self,
Expand Down
2 changes: 1 addition & 1 deletion openrl/drivers/rl_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@
# limitations under the License.

""""""
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional

from abc import ABC, abstractmethod
from openrl.drivers.base_driver import BaseDriver
from openrl.utils.logger import Logger

Expand Down
Loading

0 comments on commit dff448d

Please sign in to comment.