main.py (forked from kimoyami/PRDC)
import numpy as np
import time
import os
import d4rl
from utils.eval import eval_policy
from utils.config import get_config, save_config
from utils.logger import get_logger, get_writer
from utils.buffer import ReplayBuffer
from prdc import PRDC
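
# Entry point: train PRDC on a D4RL offline dataset and periodically evaluate the
# learned policy. Assumes the PRDC agent, ReplayBuffer, and the utils helpers from
# this repository are importable as above.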
if __name__ == "__main__":
    start_time = time.time()
    out = "result"
    os.makedirs(out, exist_ok=True)
    args, env, kwargs = get_config("PRDC")
    result_dir = os.path.join(
        out,
        time.strftime("%m-%d-%H:%M:%S")
        + "_"
        + args.policy
        + "_"
        + args.env_id
        + "_"
        + str(args.seed),
    )
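    # get_writer/get_logger come from utils.logger; the add_scalar calls below suggest
    # a TensorBoard SummaryWriter, but the exact backend depends on that helper.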
    writer = get_writer(result_dir)
    file_name = f"{args.policy}_{args.env_id}_{args.seed}"
    logger = get_logger(os.path.join(result_dir, file_name + ".log"))
    logger.info(
        f"Policy: {args.policy}, Env: {args.env_id}, Seed: {args.seed}, Info: {args.info}"
    )
    # save configs
    save_config(args, os.path.join(result_dir, "config.txt"))
    # load model
    if args.load_model != "default":
        model_name = args.load_model
    else:
        model_name = file_name
    ckpt_dir = os.path.join(result_dir, "ckpt")
    os.makedirs(ckpt_dir, exist_ok=True)
    model_path = os.path.join(ckpt_dir, model_name + ".pth")
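    # Build the offline replay buffer from the D4RL Q-learning dataset; states can be
    # normalized to zero mean / unit variance, and the resulting (mean, std) are passed
    # to eval_policy later so evaluation observations match the training statistics.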
    replay_buffer = ReplayBuffer(
        kwargs["state_dim"],
        kwargs["action_dim"],
        args.device,
        args.env_id,
        args.scale,
        args.shift,
    )
    replay_buffer.convert_D4RL(d4rl.qlearning_dataset(env))
    if args.normalize:
        mean, std = replay_buffer.normalize_states()
    else:
        mean, std = 0, 1
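    # PRDC regularizes the actor toward the dataset: the agent receives the stacked
    # (beta * state, action) array, presumably used for nearest-neighbor lookups of
    # dataset actions close to the current state (the dataset constraint in the paper).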
    states = replay_buffer.state
    actions = replay_buffer.action
    data = np.hstack([args.beta * states, actions])
    policy = PRDC(data, **kwargs)
    evaluations = []
    evaluation_path = os.path.join(result_dir, file_name + ".npy")
    if os.path.exists(model_path):
        policy.load(model_path)
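    # Main offline training loop: each step samples a batch from the replay buffer,
    # updates the agent, and logs the returned training metrics. Every eval_freq steps
    # the policy is rolled out in the environment and the D4RL normalized score is logged.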
    for t in range(int(args.max_timesteps)):
        result = policy.train(replay_buffer, args.batch_size)
        for key, value in result.items():
            writer.add_scalar(key, value, global_step=t)
        # Evaluate episode
        if (t + 1) % args.eval_freq == 0:
            model_path = os.path.join(ckpt_dir, model_name + "_" + str(t + 1) + ".pth")
            video_path = os.path.join(ckpt_dir, model_name + "_" + str(t + 1) + ".gif")
            if args.save_model and (t + 1) % args.save_model_freq == 0:
                avg_reward, d4rl_score = eval_policy(
                    policy,
                    args.env_id,
                    args.seed,
                    mean,
                    std,
                    save_gif=False,
                    video_path=video_path,
                )
                policy.save(model_path)
            else:
                avg_reward, d4rl_score = eval_policy(
                    policy, args.env_id, args.seed, mean, std
                )
            writer.add_scalar("avg_reward", avg_reward, global_step=t)
            writer.add_scalar("d4rl_score", d4rl_score, global_step=t)
            evaluations.append(d4rl_score)
            logger.info("---------------------------------------")
            logger.info(f"Time steps: {t + 1}, D4RL score: {d4rl_score}")
            np.save(evaluation_path, evaluations)
    end_time = time.time()
    logger.info(f"Total Time: {end_time - start_time}")