-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathTestMain.py
64 lines (55 loc) · 2.02 KB
/
TestMain.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import glob
import os
import torch
from algorithm.PPO import PPO
from utils.ConfigHelper import ConfigHelper
from utils.Logger import Logger
from utils.obs_2_tensor import _obs_2_tensor
from utils.recurrent_cell_init import recurrent_cell_init
class TestMain:
    """Restore a trained PPO agent from a checkpoint and run it in
    evaluation mode.

    Subclasses must override ``_run`` with the actual test/rollout loop;
    it is invoked automatically (under ``torch.no_grad()``) at the end of
    ``__init__``.
    """

    def __init__(
        self,
        obs_space: tuple,
        action_space: tuple,
        conf: ConfigHelper,
        ckpt: int = None,
    ):
        """Load normalizer + agent weights and immediately start the test run.

        Args:
            obs_space: Observation-space description forwarded to ``PPO``.
            action_space: Action-space description forwarded to ``PPO``.
            conf: Run configuration.
            ckpt: Checkpoint id to load (``{ckpt}.pth``). When ``None`` the
                most recently created file in the checkpoint directory is used.

        Raises:
            ValueError: If ``ckpt`` is None and the checkpoint directory
                contains no files.
        """
        self.conf = conf
        self.logger = Logger(self.conf, True)
        # Restore the observation normalizer saved during training; its
        # statistics are frozen here (it is only ever called with update=False).
        self.state_normalizer = self.logger.load_pickle("state_normalizer.pkl")
        self.state_normalizer.config = self.conf
        self.agent = PPO(obs_space, action_space, self.conf)

        # Resolve which checkpoint file to load.
        if ckpt is None:
            checkpoints = glob.glob(os.path.join(self.logger.checkpoint_path, "*"))
            if not checkpoints:
                # Same exception type max() would have raised on an empty
                # sequence, but with an actionable message.
                raise ValueError(
                    f"no checkpoint files found in {self.logger.checkpoint_path}"
                )
            latest_checkpoint = max(checkpoints, key=os.path.getctime)
        else:
            latest_checkpoint = os.path.join(self.logger.checkpoint_path, f"{ckpt}.pth")
        print(f"resume from {latest_checkpoint}")
        self.agent.load(latest_checkpoint)
        self.agent.policy.eval()

        # Recurrent hidden state for a single environment (batch size 1).
        self.h_in = recurrent_cell_init(
            1, self.conf.hidden_state_size, self.conf.layer_type, self.conf.device
        )
        # Evaluation only — disable autograd for the whole run.
        with torch.no_grad():
            self._run()

    def _obs_preprocess(self, obs):
        """Normalize a raw observation and convert it to batched tensors.

        Returns a list of tensors on ``conf.device``; a leading batch
        dimension is added when the observation is unbatched.
        """
        state = self.state_normalizer(obs, update=False)
        state = _obs_2_tensor(state, self.conf.device)
        if len(state[-1].shape) < 2:
            # Unbatched observation — add a batch dimension of 1.
            state = [s.unsqueeze(0) for s in state]
        return state

    def reset(self):
        """Reset the recurrent hidden state (call at episode boundaries)."""
        self.h_in = recurrent_cell_init(
            1, self.conf.hidden_state_size, self.conf.layer_type, self.conf.device
        )

    def select_action(self, obs, is_ros: bool = False):
        """Select an action for ``obs``, updating the recurrent state.

        Args:
            obs: Raw observation (normalized and tensorized internally).
            is_ros: Forwarded to ``eval_select_action``.

        Returns:
            The chosen action as a NumPy array on the CPU.
        """
        obs = self._obs_preprocess(obs)
        action, self.h_in, _ = self.agent.eval_select_action(
            obs, self.h_in, module_index=-1, is_ros=is_ros
        )
        return action.cpu().numpy()

    def _run(self):
        """Test loop entry point; must be implemented by subclasses.

        Note: the original signature omitted ``self``, so calling
        ``self._run()`` raised TypeError instead of NotImplementedError.
        """
        raise NotImplementedError