-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
74 lines (61 loc) · 2.38 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import json
import multiprocessing as mp
import torch
import yaml
import math
from multiprocessing import Queue, Manager, Barrier
from multi_task import MultiTasks
from runtime import TaskEngine
from models.common import DetectMultiBackend
def lcm_list(lst: list) -> int:
    """Return the least common multiple of all integers in ``lst``.

    The result is folded pairwise with ``lcm(a, b) = a * b // gcd(a, b)``.
    If any element is zero the result is zero.

    Args:
        lst: Non-empty sequence of integers.

    Returns:
        The least common multiple of the elements.

    Raises:
        IndexError: If ``lst`` is empty.
        TypeError: If an element is not an integer (raised by ``math.gcd``).
    """
    lcm = lst[0]
    for value in lst[1:]:
        divisor = math.gcd(lcm, value)
        # gcd(0, 0) is 0, which would make the division below blow up.
        # The LCM involving zero is defined as zero, so short-circuit instead.
        lcm = lcm * value // divisor if divisor else 0
    return lcm
if __name__ == '__main__':
    # Script entry point: load the task configuration, start one TaskEngine
    # process per configured task plus one MultiTasks process that feeds the
    # per-task queues, then aggregate the timing / deadline statistics that
    # every engine reports back.
    mp.set_start_method('spawn')
    torch.multiprocessing.set_sharing_strategy('file_system')

    filePath = "/home/air/yolov5/our_config.yaml"
    with open(file=filePath, mode="rb") as f:
        config = yaml.load(f, Loader=yaml.FullLoader)

    tasks = config['task']
    ddls = config['deadline']
    batch = config['streams']

    # One party per task engine plus one more -- presumably the MultiTasks
    # feeder; confirm against the wait() calls in MultiTasks / TaskEngine.
    shared_barr = Barrier(len(tasks) + 1)
    # Track the execution time and task processing status for each process
    results = [Queue() for _ in tasks]
    shared_queues = [Queue() for _ in tasks]

    # NOTE(review): deadlines look like milliseconds (divided by 1000 below),
    # so lifetime would be 20x the LCM of all deadlines in seconds -- confirm
    # the units expected by MultiTasks.
    mocked_tasks = MultiTasks(
        queues=shared_queues,
        barrier=shared_barr,
        lifetime=lcm_list(ddls) * 20 / 1000,
        intervals=[d / 1000 for d in ddls],
    )

    processes = list()
    for j, task in enumerate(tasks):
        processes.append(
            TaskEngine(
                name='task_%d' % (j),
                barrier=shared_barr,
                result=results[j],
                img_queue=shared_queues[j],
                batch=batch,
                model=task,
                ddl=ddls[j],
            )
        )

    mocked_tasks.start()
    for ps in processes:
        ps.start()

    mocked_tasks.join()

    total_violated = 0
    total_task = 0
    timing = list()
    # Drain every result queue BEFORE joining the workers: a process that has
    # put data on a multiprocessing.Queue may not terminate until that data
    # has been consumed, so join-then-get can deadlock. get() blocks until
    # the corresponding engine has reported.
    for result in results:
        # Each engine reports a JSON array: [start, end, violated, task_num].
        start, end, violated, task_num = json.loads(result.get())
        timing.append(start)
        timing.append(end)
        total_violated += violated
        total_task += task_num

    for ps in processes:
        ps.join()

    # Guard against engines that processed no tasks at all; the original
    # division raised ZeroDivisionError in that case.
    qos_violation = round(total_violated / total_task, 3) if total_task else 0.0
    print("Overall time: {} second".format(max(timing) - min(timing)))
    # Round the percentage so float noise (e.g. 0.7000000000000001) is not
    # printed.
    print("The number of timed-out task is {} and QoS violation rate: {}%".format(total_violated, round(100 * qos_violation, 1)))