trainer.py
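"""Trainer for a model-based RL algorithm: `train_dynamics` fits and saves the
dynamics model, and `train_policy` runs the policy optimization loop with
periodic model rollouts, per-epoch evaluation, and policy checkpointing."""
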
import os
import time

import numpy as np
import torch
from tqdm import tqdm


class Trainer:
    """Args:
        algo: algorithm object exposing `learn_dynamics`, `save_dynamics_model`,
            `rollout_transitions`, `learn_policy`, and a `policy` attribute.
        eval_env: gym-style environment used for evaluation.
        epoch: number of training epochs.
        step_per_epoch: policy update steps per epoch.
        rollout_freq: roll out model transitions every this many steps.
        logger: logger exposing `record`, `print`, and `writer.get_logdir()`.
        log_freq: record losses every this many steps.
        eval_episodes: number of evaluation episodes per epoch.
    """

    def __init__(
        self,
        algo,
        eval_env,
        epoch,
        step_per_epoch,
        rollout_freq,
        logger,
        log_freq,
        eval_episodes=10
    ):
        self.algo = algo
        self.eval_env = eval_env
        self._epoch = epoch
        self._step_per_epoch = step_per_epoch
        self._rollout_freq = rollout_freq
        self.logger = logger
        self._log_freq = log_freq
        self._eval_episodes = eval_episodes

    def train_dynamics(self):
        start_time = time.time()
        self.algo.learn_dynamics()
        # self.algo.save_dynamics_model(
        #     save_path=os.path.join(self.logger.writer.get_logdir(), "dynamics_model")
        # )
        self.algo.save_dynamics_model("dynamics_model")
        self.logger.print("total time: {:.3f}s".format(time.time() - start_time))

    def train_policy(self):
        start_time = time.time()
        num_timesteps = 0
        # train loop
        for e in range(1, self._epoch + 1):
            self.algo.policy.train()
            with tqdm(total=self._step_per_epoch, desc=f"Epoch #{e}/{self._epoch}") as t:
                while t.n < t.total:
                    if num_timesteps % self._rollout_freq == 0:
                        self.algo.rollout_transitions()
                    # update policy by SAC
                    loss = self.algo.learn_policy()
                    t.set_postfix(**loss)
                    # log
                    if num_timesteps % self._log_freq == 0:
                        for k, v in loss.items():
                            self.logger.record(k, v, num_timesteps, printed=False)
                    num_timesteps += 1
                    t.update(1)
            # evaluate current policy
            eval_info = self._evaluate()
            ep_reward_mean, ep_reward_std = np.mean(eval_info["eval/episode_reward"]), np.std(eval_info["eval/episode_reward"])
            ep_length_mean, ep_length_std = np.mean(eval_info["eval/episode_length"]), np.std(eval_info["eval/episode_length"])
            self.logger.record("eval/episode_reward", ep_reward_mean, num_timesteps, printed=False)
            self.logger.record("eval/episode_length", ep_length_mean, num_timesteps, printed=False)
            self.logger.print(f"Epoch #{e}: episode_reward: {ep_reward_mean:.3f} ± {ep_reward_std:.3f}, episode_length: {ep_length_mean:.3f} ± {ep_length_std:.3f}")
            # save policy
            torch.save(self.algo.policy.state_dict(), os.path.join(self.logger.writer.get_logdir(), "policy.pth"))
        self.logger.print("total time: {:.3f}s".format(time.time() - start_time))

    def _evaluate(self):
        self.algo.policy.eval()
        obs = self.eval_env.reset()
        eval_ep_info_buffer = []
        num_episodes = 0
        episode_reward, episode_length = 0, 0
        while num_episodes < self._eval_episodes:
            action = self.algo.policy.sample_action(obs, deterministic=True)
            next_obs, reward, terminal, _ = self.eval_env.step(action)
            episode_reward += reward
            episode_length += 1
            obs = next_obs
            if terminal:
                eval_ep_info_buffer.append(
                    {"episode_reward": episode_reward, "episode_length": episode_length}
                )
                num_episodes += 1
                episode_reward, episode_length = 0, 0
                obs = self.eval_env.reset()
        return {
            "eval/episode_reward": [ep_info["episode_reward"] for ep_info in eval_ep_info_buffer],
            "eval/episode_length": [ep_info["episode_length"] for ep_info in eval_ep_info_buffer]
        }
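

# Example usage (a minimal sketch, not part of the original file): the
# hyperparameter values below are placeholders, and `algo`, `eval_env`, and
# `logger` stand in for objects constructed elsewhere in the repo.
#
# trainer = Trainer(
#     algo=algo,
#     eval_env=eval_env,
#     epoch=100,
#     step_per_epoch=1000,
#     rollout_freq=1000,
#     logger=logger,
#     log_freq=100,
#     eval_episodes=10,
# )
# trainer.train_dynamics()   # fit and save the dynamics model first
# trainer.train_policy()     # then run the policy optimization and evaluation loop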