import argparse
import sys
import time
from typing import List

import cv2
import numpy as np

from picamera2 import CompletedRequest, MappedArray, Picamera2
from picamera2.devices import IMX500
from picamera2.devices.imx500 import NetworkIntrinsics
from picamera2.devices.imx500.postprocess import softmax

last_detections = []
LABELS = None


class Classification:
    def __init__(self, idx: int, score: float):
        """Create a Classification object, recording the idx and score."""
        self.idx = idx
        self.score = score


def get_label(request: CompletedRequest, idx: int) -> str:
    """Retrieve the label corresponding to the classification index."""
    global LABELS
    if LABELS is None:
        LABELS = intrinsics.labels
        assert len(LABELS) in [1000, 1001], "Labels file should contain 1000 or 1001 labels."
        output_tensor_size = imx500.get_output_shapes(request.get_metadata())[0][0]
        if output_tensor_size == 1000:
            LABELS = LABELS[1:]  # Ignore the background label if present
    return LABELS[idx]


def parse_and_draw_classification_results(request: CompletedRequest):
    """Analyse and draw the classification results in the output tensor."""
    results = parse_classification_results(request)
    draw_classification_results(request, results)


def parse_classification_results(request: CompletedRequest) -> List[Classification]:
    """Parse the output tensor into the classification results above the threshold."""
    global last_detections
    np_outputs = imx500.get_outputs(request.get_metadata())
    if np_outputs is None:
        return last_detections
    np_output = np_outputs[0]
    if intrinsics.softmax:
        np_output = softmax(np_output)
    top_indices = np.argpartition(-np_output, 3)[:3]  # Get top 3 indices with the highest scores
    top_indices = top_indices[np.argsort(-np_output[top_indices])]  # Sort the top 3 indices by their scores
    last_detections = [Classification(index, np_output[index]) for index in top_indices]
    return last_detections


def draw_classification_results(request: CompletedRequest, results: List[Classification], stream: str = "main"):
    """Draw the classification results for this request onto the ISP output."""
    with MappedArray(request, stream) as m:
        if intrinsics.preserve_aspect_ratio:
            # Drawing ROI box
            b_x, b_y, b_w, b_h = imx500.get_roi_scaled(request)
            color = (255, 0, 0)  # red
            cv2.putText(m.array, "ROI", (b_x + 5, b_y + 15), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)
            cv2.rectangle(m.array, (b_x, b_y), (b_x + b_w, b_y + b_h), (255, 0, 0, 0))
            text_left, text_top = b_x, b_y + 20
        else:
            text_left, text_top = 0, 0
        # Drawing labels (in the ROI box if it exists)
        for index, result in enumerate(results):
            label = get_label(request, idx=result.idx)
            text = f"{label}: {result.score:.3f}"

            # Calculate text size and position
            (text_width, text_height), baseline = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
            text_x = text_left + 5
            text_y = text_top + 15 + index * 20

            # Create a copy of the array to draw the background with opacity
            overlay = m.array.copy()

            # Draw the background rectangle on the overlay
            cv2.rectangle(overlay,
                          (text_x, text_y - text_height),
                          (text_x + text_width, text_y + baseline),
                          (255, 255, 255),  # Background color (white)
                          cv2.FILLED)

            alpha = 0.3
            cv2.addWeighted(overlay, alpha, m.array, 1 - alpha, 0, m.array)

            # Draw text on top of the background
            cv2.putText(m.array, text, (text_x, text_y),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1)


def get_args():
    """Parse command line arguments."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, help="Path of the model",
                        default="/usr/share/imx500-models/imx500_network_mobilenet_v2.rpk")
    parser.add_argument("--fps", type=int, help="Frames per second")
    parser.add_argument("-s", "--softmax", action=argparse.BooleanOptionalAction, help="Add post-process softmax")
    parser.add_argument("-r", "--preserve-aspect-ratio", action=argparse.BooleanOptionalAction,
                        help="Preprocess the image while preserving the aspect ratio")
    parser.add_argument("--labels", type=str,
                        help="Path to the labels file")
    parser.add_argument("--print-intrinsics", action="store_true",
                        help="Print JSON network_intrinsics then exit")
    return parser.parse_args()
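

# Example usage (for reference; the model path is the argparse default above and the
# labels file is the fallback loaded in __main__ below):
#
#   python imx500_classification_demo.py
#   python imx500_classification_demo.py --softmax --preserve-aspect-ratio \
#       --labels assets/imagenet_labels.txt
#   python imx500_classification_demo.py --print-intrinsics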


if __name__ == "__main__":
    args = get_args()

    # This must be called before instantiation of Picamera2
    imx500 = IMX500(args.model)
    intrinsics = imx500.network_intrinsics
    if not intrinsics:
        intrinsics = NetworkIntrinsics()
        intrinsics.task = "classification"
    elif intrinsics.task != "classification":
        print("Network is not a classification task", file=sys.stderr)
        exit()

    # Override intrinsics from args
    for key, value in vars(args).items():
        if key == 'labels' and value is not None:
            with open(value, 'r') as f:
                intrinsics.labels = f.read().splitlines()
        elif hasattr(intrinsics, key) and value is not None:
            setattr(intrinsics, key, value)

    # Defaults
    if intrinsics.labels is None:
        with open("assets/imagenet_labels.txt", "r") as f:
            intrinsics.labels = f.read().splitlines()
    intrinsics.update_with_defaults()

    if args.print_intrinsics:
        print(intrinsics)
        exit()

    picam2 = Picamera2(imx500.camera_num)
    config = picam2.create_preview_configuration(controls={"FrameRate": intrinsics.inference_rate}, buffer_count=12)

    imx500.show_network_fw_progress_bar()
    picam2.start(config, show_preview=True)

    if intrinsics.preserve_aspect_ratio:
        imx500.set_auto_aspect_ratio()

    # Register the callback to parse and draw classification results
    picam2.pre_callback = parse_and_draw_classification_results

    while True:
        time.sleep(0.5)