-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
94 lines (74 loc) · 3.31 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import argparse
import logging
import envpool
import custom_env
import gymnasium as gym
from gymnasium.vector import AsyncVectorEnv
from agent import PPOAgent
from utils.general import get_config
from utils.envs import (
create_mujoco_env
)
def parse_args():
    """Parse the command-line options of this script from ``sys.argv``.

    Recognised options:
        --config           str, no default (``None`` when omitted).
        --exp_name         str, default ``None``.
        --train            flag; true when present.
        --eval             flag; true when present.
        --eval_n_episode   int, default 10.
        --load_postfix     str, default ``None``.
        --experiment_path  str, default ``None``.
        --video_path       str, default ``'videos'``.
        --not_resume       flag; true when present.
        --desc             str, default ``""``.

    Returns:
        The ``argparse.Namespace`` holding the parsed values.
    """
    # One (flag, add_argument-kwargs) entry per option, in registration order.
    option_table = (
        ("--config", {"type": str}),
        ("--exp_name", {"type": str, "default": None}),
        ("--train", {"action": 'store_true'}),
        ("--eval", {"action": 'store_true'}),
        ("--eval_n_episode", {"type": int, "default": 10}),
        ("--load_postfix", {
            "type": str,
            "default": None,
            "help": "pretrained model prefix(ex/ number of episode, 'best' or 'last') from same experiments",
        }),
        ("--experiment_path", {
            "type": str,
            "default": None,
            "help": "path to pretrained model ",
        }),
        ("--video_path", {
            "type": str,
            "default": 'videos',
            "help": "path to saving playing video ",
        }),
        ("--not_resume", {"action": 'store_true'}),
        ("--desc", {
            "type": str,
            "default": "",
            "help": "Additional description of the executing code",
        }),
    )

    cli = argparse.ArgumentParser()
    for flag, spec in option_table:
        cli.add_argument(flag, **spec)
    return cli.parse_args()
def make_env(env_name):
    """Return a zero-argument factory that builds one ``env_name`` environment.

    The environment is not created here; it is created each time the
    returned callable is invoked, via ``gym.make`` with
    ``env_config={"render_mode": "rgb_array"}``.

    Args:
        env_name: Environment id handed to ``gym.make``.

    Returns:
        A callable taking no arguments and returning the new environment.
    """
    def _init():
        env_settings = {"render_mode": "rgb_array"}
        return gym.make(env_name, env_config=env_settings)

    return _init
def main():
    """Build or restore a PPO agent, then train and/or evaluate it.

    Flow:
        1. Parse the command line and configure root logging to a stream
           handler at INFO level.
        2. If both ``--load_postfix`` and ``--experiment_path`` are given,
           restore the agent with ``PPOAgent.load``; otherwise build a new
           agent from the config returned by ``get_config(args.config)``.
        3. With ``--train``: create vectorised environments (envpool when
           the environment name is listed by ``envpool.list_all_envs()``,
           otherwise an ``AsyncVectorEnv`` of ``make_env`` factories) and
           call ``trainer.step``.
        4. With ``--eval``: create a single environment and call
           ``trainer.play`` for ``--eval_n_episode`` episodes.

    ``--train`` and ``--eval`` are independent; when both are given,
    training runs first and evaluation follows.
    """
    args = parse_args()

    # Setting logging
    logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                        datefmt='%m/%d/%Y %H:%M:%S',
                        level=logging.INFO,
                        handlers=[logging.StreamHandler()])
    logging.info("Description: %s", args.desc)

    if args.load_postfix and args.experiment_path:
        trainer = PPOAgent.load(experiment_path=args.experiment_path,
                                postfix=args.load_postfix,
                                resume=not args.not_resume)
    else:
        # Get config
        config = get_config(args.config)
        trainer = PPOAgent(config)

    # Always read environment settings from the agent itself. The local
    # ``config`` only exists when a fresh agent was built above, so using it
    # after ``PPOAgent.load`` raised NameError in the AsyncVectorEnv branch.
    env_cfg = trainer.config.env

    if args.train:
        if env_cfg.env_name in envpool.list_all_envs():
            envs = envpool.make(env_cfg.env_name,
                                env_type="gymnasium",
                                num_envs=env_cfg.num_envs)
        else:
            envs = AsyncVectorEnv([make_env(env_cfg.env_name)
                                   for _ in range(env_cfg.num_envs)])
        trainer.step(envs, args.exp_name)

    if args.eval:
        if env_cfg.env_name in custom_env.env_list:
            # --video_path defaults to 'videos', so the on-screen ('human')
            # branch is only reached when an empty path is passed explicitly.
            if args.video_path:
                env = gym.make(env_cfg.env_name, env_config={"render_mode": 'rgb_array'})
                env = gym.wrappers.RecordVideo(env, args.video_path)
            else:
                env = gym.make(env_cfg.env_name, env_config={"render_mode": 'human'})
        else:
            env = create_mujoco_env(env_cfg.env_name, video_path=args.video_path)
        trainer.play(
            env=env,
            num_episodes=args.eval_n_episode,
            max_ep_len=2048
        )
# Run the command-line workflow only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()