-
Notifications
You must be signed in to change notification settings - Fork 5
/
run_maxpressure.py
executable file
·118 lines (99 loc) · 4.24 KB
/
run_maxpressure.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
from utils.utils import oneline_wrapper
import os
import time
from multiprocessing import Process
import argparse
def parse_args(argv=None):
    """Parse the command-line options for a MaxPressure benchmark run.

    Args:
        argv: Optional list of argument strings to parse. When None (the
            default) argparse falls back to ``sys.argv[1:]``, which is the
            behaviour callers had before this parameter existed.

    Returns:
        argparse.Namespace with the attributes ``memo``, ``model``,
        ``eightphase``, ``multi_process``, ``workers``, ``hangzhou`` and
        ``jinan``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--memo", type=str, default='benchmark_1001',
                        help="sub-directory name used under model/ and records/")
    parser.add_argument("-model", type=str, default="MaxPressure",
                        help="value stored as MODEL_NAME in the env config")
    parser.add_argument("-eightphase", action="store_true", default=False,
                        help="use the eight-entry PHASE / PHASE_LIST override")
    # store_true combined with default=True can never produce False, so this
    # flag alone cannot switch multiprocessing off; it is kept for
    # compatibility with existing command lines.
    parser.add_argument("-multi_process", action="store_true", default=True,
                        help="run each traffic file in its own process (default)")
    # Explicit opt-out writing to the same destination, so the serial code
    # path is actually reachable from the command line.
    parser.add_argument("-no_multi_process", dest="multi_process",
                        action="store_false",
                        help="run the traffic files one after another in this process")
    parser.add_argument("-workers", type=int, default=3,
                        help="maximum number of processes running at once")
    parser.add_argument("-hangzhou", action="store_true", default=False,
                        help="use the Hangzhou 4_4 traffic files")
    # Always True (store_true + default=True); main() tests hangzhou first,
    # so Jinan effectively acts as the fallback dataset.
    parser.add_argument("-jinan", action="store_true", default=True,
                        help="use the Jinan 3_4 traffic files (default)")
    return parser.parse_args(argv)
def main(in_args):
    """Build one configuration per traffic file and run them.

    For every traffic file of the selected dataset this assembles the agent,
    traffic-environment and path dictionaries and passes them to
    ``oneline_wrapper`` -- either directly (serial mode) or through one
    ``multiprocessing.Process`` per file, with at most ``in_args.workers``
    processes alive at the same time.

    Args:
        in_args: Namespace-like object exposing ``hangzhou``, ``jinan``,
            ``model``, ``memo``, ``eightphase``, ``multi_process`` and
            ``workers`` (as produced by ``parse_args``).

    Raises:
        ValueError: If neither ``in_args.hangzhou`` nor ``in_args.jinan`` is set.
    """
    count, road_net, traffic_file_list, template = _select_dataset(in_args)

    # road_net is "<rows>_<cols>".
    rows_text, cols_text = road_net.split('_')[:2]
    NUM_ROW = int(rows_text)
    NUM_COL = int(cols_text)
    num_intersections = NUM_ROW * NUM_COL
    print('num_intersections:', num_intersections)
    print(traffic_file_list)

    process_list = []
    for traffic_file in traffic_file_list:
        dic_traffic_env_conf_extra = {
            "NUM_AGENTS": num_intersections,
            "NUM_INTERSECTIONS": num_intersections,

            "MODEL_NAME": in_args.model,
            "RUN_COUNTS": count,
            "NUM_ROW": NUM_ROW,
            "NUM_COL": NUM_COL,

            "TRAFFIC_FILE": traffic_file,
            "ROADNET_FILE": "roadnet_{0}.json".format(road_net),

            "LIST_STATE_FEATURE": [
                "cur_phase",
                "traffic_movement_pressure_queue",
            ],

            "DIC_REWARD_INFO": {
                "pressure": 0
            },
        }

        if in_args.eightphase:
            dic_traffic_env_conf_extra["PHASE"] = {
                1: [0, 1, 0, 1, 0, 0, 0, 0],
                2: [0, 0, 0, 0, 0, 1, 0, 1],
                3: [1, 0, 1, 0, 0, 0, 0, 0],
                4: [0, 0, 0, 0, 1, 0, 1, 0],
                5: [1, 1, 0, 0, 0, 0, 0, 0],
                6: [0, 0, 1, 1, 0, 0, 0, 0],
                7: [0, 0, 0, 0, 0, 0, 1, 1],
                8: [0, 0, 0, 0, 1, 1, 0, 0]
            }
            dic_traffic_env_conf_extra["PHASE_LIST"] = ['WT_ET', 'NT_ST', 'WL_EL', 'NL_SL',
                                                        'WL_WT', 'EL_ET', 'SL_ST', 'NL_NT']

        dic_agent_conf_extra = {
            "FIXED_TIME": [15, 15, 15, 15],
        }

        # Take the timestamp once so the model and records directories of a
        # run always share the same suffix, even across a second boundary.
        run_name = traffic_file + "_" + time.strftime('%m_%d_%H_%M_%S',
                                                      time.localtime(time.time()))
        dic_path_extra = {
            "PATH_TO_MODEL": os.path.join("model", in_args.memo, run_name),
            "PATH_TO_WORK_DIRECTORY": os.path.join("records", in_args.memo, run_name),
            "PATH_TO_DATA": os.path.join("data", template, str(road_net))
        }

        if in_args.multi_process:
            process_list.append(Process(target=oneline_wrapper,
                                        args=(dic_agent_conf_extra,
                                              dic_traffic_env_conf_extra,
                                              dic_path_extra)))
        else:
            oneline_wrapper(dic_agent_conf_extra, dic_traffic_env_conf_extra, dic_path_extra)

    if in_args.multi_process:
        _run_with_worker_limit(process_list, in_args.workers)


def _select_dataset(in_args):
    """Return (run_counts, road_net, traffic_files, template) for the chosen dataset."""
    # hangzhou is tested first, so it wins when both flags are set.
    if in_args.hangzhou:
        return (3600, "4_4",
                ["anon_4_4_hangzhou_real.json",
                 "anon_4_4_hangzhou_real_5816.json"],
                "Hangzhou")
    if in_args.jinan:
        return (3600, "3_4",
                ["anon_3_4_jinan_real.json", "anon_3_4_jinan_real_2000.json",
                 "anon_3_4_jinan_real_2500.json"],
                "Jinan")
    raise ValueError("no dataset selected: set hangzhou or jinan")


def _run_with_worker_limit(process_list, workers):
    """Start every process with at most `workers` alive at once, then join them all."""
    limit = max(1, workers)  # a non-positive limit would otherwise wait forever
    running = []
    for index, proc in enumerate(process_list):
        # Wait for a free slot instead of dropping the processes that do not
        # fit into the first batch.
        while len(running) >= limit:
            running = [p for p in running if p.is_alive()]
            if len(running) >= limit:
                time.sleep(1)
        print(index)
        proc.start()
        running.append(proc)
    # Every process has been started by now; joining a finished one returns at once.
    for proc in process_list:
        proc.join()
if __name__ == "__main__":
    # Script entry point: parse the command-line flags and pass them to main().
    main(parse_args())