"""PPO training entry point: rolls out vectorized gymnasium environments and logs
episodic statistics to TensorBoard and Weights & Biases."""

import logging
import pathlib
import sys
from datetime import datetime

import gymnasium as gym
import numpy as np
import torch
import wandb
import yaml
from omegaconf import OmegaConf
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from agents import PPOagent
from envs.utils import make_env

def main(cfg):
    torch.cuda.set_device(0)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Build the vectorized environments (results_dir and writer are module-level
    # globals created in the __main__ block below).
    num_envs = cfg.env.num_envs
    envs = gym.vector.AsyncVectorEnv(
        [make_env(cfg.env.task, cfg.agent.seed + i, idx, results_dir) for idx, i in enumerate(range(num_envs))]
    )

    agent = PPOagent(envs, cfg, device)
    if cfg.agent.load_ppo_model:
        agent.load_model(cfg.agent.ppo_checkpoint_path, cfg.agent.image_checkpoint_path)

    # Rollout and update bookkeeping.
    num_steps = cfg.agent.num_steps
    batch_size = int(num_steps * num_envs)
    num_updates = cfg.agent.total_timesteps // batch_size
    global_step = 0
    initial_update = 0

    obs, _ = envs.reset()
    obs, frame = agent.process_obs(obs)
    next_done = torch.zeros(num_envs)

    for update in range(initial_update, initial_update + num_updates):
        # Collect one rollout of num_steps transitions from all environments.
        for step in tqdm(range(num_steps), desc=f"Update {update + 1}/{initial_update + num_updates} ", unit="step", file=sys.stdout):
            global_step += num_envs
            action, logprob, _, val = agent.select_action(obs)
            next_obs, reward, done, _, info = envs.step(action.cpu().numpy())
            agent.store_experience(obs, action, logprob, torch.tensor(reward), next_done, val.squeeze(), frame)
            obs, frame = agent.process_obs(next_obs)
            next_done = torch.Tensor(done).to(device)

            # Log episodic statistics for environments that just finished an episode.
            if "final_info" in info:
                for ind, agent_info in enumerate(info["final_info"]):
                    if agent_info is not None:
                        ep_rew = agent_info["episode"]["r"]
                        ep_len = agent_info["episode"]["l"]
                        logging.info(f"global step: {global_step}, agent_id={ind}, reward={ep_rew[-1]}, length={ep_len[-1]}")
                        writer.add_scalar("charts/episodic_return", ep_rew, global_step)
                        writer.add_scalar("charts/episodic_length", ep_len, global_step)

        # PPO update on the collected rollout.
        agent.learn(last_obs=obs, last_done=next_done, writer=writer, global_step=global_step)

        # Checkpoint roughly 40 times over training (every update if fewer than 40 updates).
        if num_updates < 40:
            agent.save_model(update + 1)
        elif (update + 1) % (num_updates // 40) == 0:
            agent.save_model(update + 1)

    envs.close()

if __name__ == "__main__":
    dir_path = pathlib.Path(__file__).parent.resolve()
    with open(dir_path.joinpath("config.yaml"), "r") as f:  # Change config file here, e.g. conf_local.yaml
        cfg = yaml.safe_load(f)
    cfg = OmegaConf.create(cfg)

    # Run name: task + timestamp, with suffixes for the enabled agent options.
    dname = f"{cfg.env.task.replace(' ', '_')}_{datetime.now().strftime('%m_%d-%H:%M')}"
    if cfg.agent.clip_vloss:
        dname = dname + "_vclip"
    if cfg.agent.return_norm:
        dname = dname + "_rnorm"
    if cfg.agent.autocast_flag:
        dname = dname + "_autocast"

    # Mirror env and feature-net settings into the agent config so they are logged together.
    cfg.agent.n_envs = cfg.env.num_envs
    cfg.agent.tsk = cfg.env.task
    cfg.agent.image_model = cfg.feature_net_kwargs.rgb_feat.image_model

    suf_add = f"only-ppo_{cfg.feature_net_kwargs.rgb_feat.image_model}"
    if cfg.agent.train_image_model:
        suf_add = f"ppo-imgenc_{cfg.feature_net_kwargs.rgb_feat.image_model}"

    wandb.init(
        project=f"Beryllium_{suf_add}",  # Change project name
        entity=None,
        sync_tensorboard=True,
        config=dict(cfg.agent),
        name=dname,
    )

    results_dir = f"results/{suf_add}/{dname}"
    cfg.results_dir = results_dir
    writer = SummaryWriter(results_dir)
    writer.add_text(
        "hyperparameters",
        "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in cfg.agent.items()])),
    )

    # Redirect stderr and configure file logging inside the results directory.
    sys.stderr = open(results_dir + "/err.e", "w")
    log_file = f"{cfg.results_dir}/output.log"
    logging.basicConfig(
        filename=log_file,
        format="[%(asctime)s] [%(levelname)8s] --- %(message)s (%(filename)s:%(lineno)s)",
        datefmt="%Y-%m-%d %H:%M:%S",
        level=logging.INFO,
        filemode="w",
    )

    main(cfg)
    writer.close()
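
# ---------------------------------------------------------------------------
# For reference, a hypothetical sketch of the config.yaml layout this script
# reads. The key names are taken from the attribute accesses above; the example
# values are assumptions for illustration, not the project's actual defaults.
#
#   env:
#     task: CartPole-v1
#     num_envs: 8
#   agent:
#     seed: 1
#     num_steps: 128
#     total_timesteps: 1000000
#     load_ppo_model: false
#     ppo_checkpoint_path: ""
#     image_checkpoint_path: ""
#     clip_vloss: true
#     return_norm: false
#     autocast_flag: false
#     train_image_model: false
#   feature_net_kwargs:
#     rgb_feat:
#       image_model: resnet18
# ---------------------------------------------------------------------------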