-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodel_wrapper.py
51 lines (43 loc) · 1.57 KB
/
model_wrapper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import os
import time
import numpy as np
class ModelWrapper:
    """Serve a model over a simple file-based request/response protocol.

    A client drops a pickled dict at ``<communication_dir>/input.npy``.
    :meth:`start_service` loads it and either
    - answers it via :meth:`pred_action`, writing the result to
      ``<communication_dir>/output.npy``, or
    - for an empty dict (see :meth:`kill_signal`), resets the step counter
      and re-runs :meth:`initialize`.

    In both cases the input file is deleted afterwards. Subclasses must
    implement :meth:`load_model`, :meth:`initialize` and :meth:`pred_action`.
    """

    # Seconds to sleep between existence checks while no input file is
    # present; without it the idle loop spins a CPU core at 100%.
    poll_interval: float = 0.1
    # Seconds to wait after the input file appears before reading it.
    # NOTE(review): presumably gives the client time to finish writing the
    # file -- confirm; an atomic rename on the client side would make this
    # delay unnecessary.
    settle_time: float = 1.0

    def __init__(self, communication_dir: str = '/comm') -> None:
        """Set up the exchange-file paths, load the model and initialize it.

        Args:
            communication_dir: Directory shared with the client that holds
                ``input.npy`` and ``output.npy``. Defaults to ``'/comm'``.
        """
        self.communication_dir = communication_dir
        self.input_file = os.path.join(self.communication_dir, 'input.npy')
        self.output_file = os.path.join(self.communication_dir, 'output.npy')
        # Predictions served since the last reset; incremented before
        # pred_action runs, so it is the 1-based index of the current request.
        self.action_time_step = 0
        self.load_model()
        self.initialize()

    def load_model(self):
        """Hook called once from ``__init__``, before :meth:`initialize`.

        Subclasses must override it.
        """
        raise NotImplementedError

    def initialize(self):
        """Hook called from ``__init__`` and again on every reset signal.

        Subclasses must override it.
        """
        raise NotImplementedError

    def pred_action(self, input_data):
        """Return the model's response for one request.

        ``input_data`` is the object loaded from the input file. Its
        expected layout, per the original docstring (confirm against the
        client)::

            dict(
                extrinsic_cv: np.array([1,3,4])
                cam2world_gl: np.array([1,4,4])
                instrinsic_cv: np.array([1,3,3])
                image: np.array([n,512,512,3])
                agent_tcp: np.array([7,])
            )

        NOTE(review): the key is spelled ``instrinsic_cv`` above; verify the
        spelling the client actually sends.

        The return value is passed unchanged to ``np.save``. Subclasses must
        override this method.
        """
        raise NotImplementedError

    def kill_signal(self, input_data) -> bool:
        """Return True when ``input_data`` is empty (the reset request)."""
        return len(input_data) == 0

    def start_service(self):
        """Serve requests forever; this method never returns normally.

        Each iteration waits for ``self.input_file`` to appear, then loads
        it and either answers it (result written to ``self.output_file``)
        or handles a reset signal. The input file is deleted afterwards so
        the next request can be detected. Exceptions raised by the hooks,
        or by loading the file, propagate and stop the loop.
        """
        while True:
            if not os.path.exists(self.input_file):
                # Idle: back off instead of busy-waiting.
                time.sleep(self.poll_interval)
                continue

            time.sleep(self.settle_time)
            # SECURITY: allow_pickle=True unpickles whatever is placed in the
            # shared directory; only run this where that directory is trusted.
            input_data = np.load(self.input_file, allow_pickle=True).item()

            if self.kill_signal(input_data):
                self.action_time_step = 0  # reset episode state
                self.initialize()
            else:
                self.action_time_step += 1
                output_data = self.pred_action(input_data)
                # Write to a temporary file and atomically rename it into
                # place, so a client polling for the output file never reads
                # a half-written one. A file object is passed to np.save so it
                # does not append ".npy" to the temporary name.
                tmp_file = self.output_file + '.tmp'
                with open(tmp_file, 'wb') as fh:
                    np.save(fh, output_data)
                os.replace(tmp_file, self.output_file)

            os.remove(self.input_file)