forked from kejingjing88212/MaCA
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathagent_process.py
145 lines (134 loc) · 4.83 KB
/
agent_process.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
#! /usr/bin/env python
# -*- coding: utf-8 -*-
"""
@author: Gao Fang
@contact: gaofang@cetc.com.cn
@software: PyCharm
@file: agent_process.py
@time: 2019/3/12 0009 16:41
@desc: agent sub proc
"""
import importlib
import os
import queue
import time
from multiprocessing import Process, Queue
AGENT_INIT_TIMEOUT = 60
AGENT_RESP_TIMEOUT = 60
class AgentProc(Process):
    """Agent child process.

    Once started, the process imports ``agent.<agent_name>.agent``, builds its
    ``Agent``, reports ``'Init OK'`` on ``send_queue`` and then answers one
    action dict per observation message read from ``recv_queue`` until the
    ``'done'`` sentinel arrives.
    """

    def __init__(self, agent_name, size_x, size_y, detector_num, fighter_num, recv_queue, send_queue, gpu_num):
        """Store the configuration; nothing is imported or started here.

        :param agent_name: package name under ``agent`` whose ``agent`` module
            provides the ``Agent`` class
        :param size_x: forwarded to ``Agent.set_map_info`` / ``ObsConstruct``
        :param size_y: forwarded to ``Agent.set_map_info`` / ``ObsConstruct``
        :param detector_num: forwarded to ``Agent.set_map_info`` / ``ObsConstruct``
        :param fighter_num: forwarded to ``Agent.set_map_info`` / ``ObsConstruct``
        :param recv_queue: queue this process reads requests from
        :param send_queue: queue this process writes replies to
        :param gpu_num: value exported as ``CUDA_VISIBLE_DEVICES`` in the child
        """
        super().__init__()
        self.agent_name = agent_name
        self.size_x = size_x
        self.size_y = size_y
        self.detector_num = detector_num
        self.fighter_num = fighter_num
        self.recv_queue = recv_queue
        self.send_queue = send_queue
        self.agent = None
        self.obs_construct = None
        # 'raw' means observations are handed to the agent unchanged.
        self.obs_ind = 'raw'
        self.gpu_num = gpu_num

    def run(self):
        """Process entry point: build the agent, report readiness, serve requests."""
        # Set before the agent module is imported so the import already sees it.
        os.environ['CUDA_VISIBLE_DEVICES'] = str(self.gpu_num)
        agent_import_path = 'agent.' + self.agent_name + '.agent'
        agent_module = importlib.import_module(agent_import_path)
        self.agent = agent_module.Agent()
        self.agent.set_map_info(self.size_x, self.size_y, self.detector_num, self.fighter_num)
        self.obs_ind = self.agent.get_obs_ind()
        if self.obs_ind != 'raw':
            # The agent asked for a constructed observation; load the matching builder.
            obs_path = 'obs_construct.' + self.obs_ind + '.construct'
            obs_module = importlib.import_module(obs_path)
            self.obs_construct = obs_module.ObsConstruct(self.size_x, self.size_y, self.detector_num,
                                                         self.fighter_num)
        # Readiness handshake for whoever is listening on send_queue.
        self.send_queue.put('Init OK')
        self.__decision_proc()

    def __decision_proc(self):
        """Answer observation messages until the ``'done'`` sentinel is received.

        Each message is a dict with keys ``'obs_raw_dict'`` and ``'step_cnt'``;
        each reply is a dict with keys ``'detector_action'`` and
        ``'fighter_action'``.
        """
        while True:
            obs_data = self.recv_queue.get()
            if obs_data == 'done':
                # Returning ends run(), so the process exits with code 0. This
                # avoids the site-provided exit() helper, which is meant for the
                # interactive shell and is not guaranteed to exist.
                return
            obs_raw_dict = obs_data['obs_raw_dict']
            step_cnt = obs_data['step_cnt']
            if self.obs_ind == 'raw':
                obs_dict = obs_raw_dict
            else:
                obs_dict = self.obs_construct.obs_construct(obs_raw_dict)
            detector_action, fighter_action = self.agent.get_action(obs_dict, step_cnt)
            action_dict = {'detector_action': detector_action, 'fighter_action': fighter_action}
            self.send_queue.put(action_dict)
class AgentCtrl:
    """Agent child-process maintenance.

    Owns one ``AgentProc`` and the two size-1 queues used to talk to it:
    ``send_q`` carries requests to the child (it is the child's receive queue)
    and ``recv_q`` carries the child's replies back.
    """

    # Seconds between liveness checks while waiting for a restarted agent.
    _RESTART_POLL_INTERVAL = 1
    # Upper bound, in seconds, on waiting for a terminated child to exit.
    _JOIN_TIMEOUT = 5

    def __init__(self, agent_name, size_x, size_y, detector_num, fighter_num, gpu_num):
        """Remember the agent configuration; no process is started here.

        :param agent_name: passed through to ``AgentProc``
        :param size_x: passed through to ``AgentProc``
        :param size_y: passed through to ``AgentProc``
        :param detector_num: passed through to ``AgentProc``
        :param fighter_num: passed through to ``AgentProc``
        :param gpu_num: passed through to ``AgentProc``
        """
        self.agent_name = agent_name
        self.size_x = size_x
        self.size_y = size_y
        self.detector_num = detector_num
        self.fighter_num = fighter_num
        self.send_q = None
        self.recv_q = None
        self.agent = None
        self.gpu_num = gpu_num

    def __spawn_agent(self):
        """Create fresh queues and start a new child process wired to them."""
        self.send_q = Queue(1)
        self.recv_q = Queue(1)
        self.agent = AgentProc(self.agent_name, self.size_x, self.size_y, self.detector_num, self.fighter_num,
                               self.send_q, self.recv_q, self.gpu_num)
        self.agent.start()

    def agent_init(self):
        """Start the agent process and wait for its readiness message.

        :return: True if the child reported in within ``AGENT_INIT_TIMEOUT``
            seconds; False otherwise (the child is terminated in that case)
        """
        self.__spawn_agent()
        try:
            self.recv_q.get(True, AGENT_INIT_TIMEOUT)
        except Exception:
            # Any failure to receive the handshake counts as a failed init.
            # KeyboardInterrupt / SystemExit are deliberately not swallowed.
            self.terminate()
            return False
        return True

    def terminate(self):
        """Stop the child process (if any) and release both queues.

        Safe to call when nothing is running and safe to call repeatedly.
        """
        if self.agent is not None:
            if self.agent.is_alive():
                try:
                    # Non-blocking: the queue holds a single item, so a
                    # blocking put could stall here if the slot is occupied.
                    self.send_q.put('done', False)
                except queue.Full:
                    pass
                self.agent.terminate()
            if self.agent.pid is not None:
                # Only a started process may be joined. The wait is bounded so
                # a child that ignores the termination signal cannot hang us.
                self.agent.join(self._JOIN_TIMEOUT)
            self.agent = None
        if self.send_q is not None:
            self.send_q.close()
            self.send_q = None
        if self.recv_q is not None:
            self.recv_q.close()
            self.recv_q = None

    def get_action(self, obs_raw_dict, step_cnt):
        """Get the agent's action for one step.

        :param obs_raw_dict: raw obs
        :param step_cnt: step counter, starting from 1
        :return: action: action information (an empty list when no reply arrived)
        :return: result: 0, normal; 1, crashed; 2, timed out
        :raises RuntimeError: if the agent had to be restarted and the new
            process died before reporting ready
        """
        action = []
        result = 0
        self.send_q.put({'obs_raw_dict': obs_raw_dict, 'step_cnt': step_cnt})
        try:
            action = self.recv_q.get(True, AGENT_RESP_TIMEOUT)
        except Exception:
            # No reply: a dead child means it crashed, a live one means it hung.
            result = 2 if self.agent.is_alive() else 1
            self.__agent_restart()
        return action, result

    def __agent_restart(self):
        """Restart the agent and wait until the new process reports ready.

        A restart implies an earlier successful start, so a child that is
        alive but slow is waited for without an overall deadline. A child that
        dies during start-up can never report in, so that case raises instead
        of blocking forever.

        :raises RuntimeError: if the new process exits before reporting ready
        """
        self.terminate()
        self.__spawn_agent()
        while True:
            try:
                self.recv_q.get(True, self._RESTART_POLL_INTERVAL)
                return
            except queue.Empty:
                if not self.agent.is_alive():
                    self.terminate()
                    raise RuntimeError(
                        'agent %r exited during restart before reporting ready' % (self.agent_name,)
                    ) from None