-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathdraw_box_utils.py
153 lines (134 loc) · 5.9 KB
/
draw_box_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
from PIL.Image import Image, fromarray
import PIL.ImageDraw as ImageDraw
import PIL.ImageFont as ImageFont
from PIL import ImageColor
import numpy as np
# Palette of color names used to paint detections; draw_objs picks one entry
# per object with ``class_id % len(STANDARD_COLORS)``, so the order matters.
STANDARD_COLORS = [
    'AliceBlue', 'Chartreuse', 'Aqua', 'Aquamarine', 'Azure', 'Beige', 'Bisque',
    'BlanchedAlmond', 'BlueViolet', 'BurlyWood', 'CadetBlue', 'AntiqueWhite',
    'Chocolate', 'Coral', 'CornflowerBlue', 'Cornsilk', 'Crimson', 'Cyan',
    'DarkCyan', 'DarkGoldenRod', 'DarkGrey', 'DarkKhaki', 'DarkOrange',
    'DarkOrchid', 'DarkSalmon', 'DarkSeaGreen', 'DarkTurquoise', 'DarkViolet',
    'DeepPink', 'DeepSkyBlue', 'DodgerBlue', 'FireBrick', 'FloralWhite',
    'ForestGreen', 'Fuchsia', 'Gainsboro', 'GhostWhite', 'Gold', 'GoldenRod',
    'Salmon', 'Tan', 'HoneyDew', 'HotPink', 'IndianRed', 'Ivory', 'Khaki',
    'Lavender', 'LavenderBlush', 'LawnGreen', 'LemonChiffon', 'LightBlue',
    'LightCoral', 'LightCyan', 'LightGoldenRodYellow', 'LightGray', 'LightGrey',
    'LightGreen', 'LightPink', 'LightSalmon', 'LightSeaGreen', 'LightSkyBlue',
    'LightSlateGray', 'LightSlateGrey', 'LightSteelBlue', 'LightYellow', 'Lime',
    'LimeGreen', 'Linen', 'Magenta', 'MediumAquaMarine', 'MediumOrchid',
    'MediumPurple', 'MediumSeaGreen', 'MediumSlateBlue', 'MediumSpringGreen',
    'MediumTurquoise', 'MediumVioletRed', 'MintCream', 'MistyRose', 'Moccasin',
    'NavajoWhite', 'OldLace', 'Olive', 'OliveDrab', 'Orange', 'OrangeRed',
    'Orchid', 'PaleGoldenRod', 'PaleGreen', 'PaleTurquoise', 'PaleVioletRed',
    'PapayaWhip', 'PeachPuff', 'Peru', 'Pink', 'Plum', 'PowderBlue', 'Purple',
    'Red', 'RosyBrown', 'RoyalBlue', 'SaddleBrown', 'Green', 'SandyBrown',
    'SeaGreen', 'SeaShell', 'Sienna', 'Silver', 'SkyBlue', 'SlateBlue',
    'SlateGray', 'SlateGrey', 'Snow', 'SpringGreen', 'SteelBlue', 'GreenYellow',
    'Teal', 'Thistle', 'Tomato', 'Turquoise', 'Violet', 'Wheat', 'White',
    'WhiteSmoke', 'Yellow', 'YellowGreen',
]
def _text_size(font, text):
    """Return ``(width, height)`` of ``text`` for ``font`` across Pillow versions.

    ``font.getsize`` was removed in Pillow 10. When the font offers
    ``getbbox``, the right and bottom edges of the bounding box are used,
    which is intended to match the extent ``getsize`` used to report (it
    included the glyph offset). Fonts without ``getbbox`` still go through
    ``getsize``.
    """
    if hasattr(font, 'getbbox'):
        _, _, right, bottom = font.getbbox(text)
        return right, bottom
    return font.getsize(text)


def draw_text(draw,
              box: list,
              cls: int,
              score: float,
              category_index: dict,
              color: str,
              font: str = 'arial.ttf',
              font_size: int = 24):
    """
    Draw the "<category>: <score>%" label of one bounding box onto the image.

    The label is placed directly above the box; when there is not enough room
    above it (the label would cross the top of the image) it is placed
    directly below the box instead. Characters are drawn one at a time, each
    on its own filled background rectangle.

    Args:
        draw: drawing context providing ``rectangle`` and ``text``
            (``draw_objs`` passes an ``ImageDraw.Draw`` instance).
        box: box coordinates as ``[left, top, right, bottom]``.
        cls: class id; looked up in ``category_index`` under ``str(cls)``.
        score: confidence of the detection, shown as ``int(100 * score)``%.
        category_index: mapping from class id (as a string) to class name.
            When it is ``None``/empty or has no entry for the id, the id
            itself is shown instead of raising.
        color: fill for the label background (``draw_objs`` passes an RGB
            tuple rather than a string).
        font: font file name/path handed to ``ImageFont.truetype``; Pillow's
            built-in default font is used when it cannot be loaded.
        font_size: font size handed to ``ImageFont.truetype``.
    """
    try:
        label_font = ImageFont.truetype(font, font_size)
    except IOError:
        label_font = ImageFont.load_default()

    left, top, _, bottom = box

    class_key = str(cls)
    class_name = category_index.get(class_key, class_key) if category_index else class_key
    display_str = f"{class_name}: {int(100 * score)}%"

    # Measure every character once; the sizes drive both the label height and
    # the horizontal advance while drawing.
    char_sizes = [_text_size(label_font, ch) for ch in display_str]

    # The label has a top and bottom margin of 0.05x the tallest character.
    display_str_height = (1 + 2 * 0.05) * max(height for _, height in char_sizes)

    # If the label added to the top of the bounding box would exceed the top
    # of the image, put it below the bounding box instead of above.
    if top > display_str_height:
        text_top = top - display_str_height
        text_bottom = top
    else:
        text_top = bottom
        text_bottom = bottom + display_str_height

    for ch, (text_width, _) in zip(display_str, char_sizes):
        margin = np.ceil(0.05 * text_width)
        draw.rectangle([(left, text_top),
                        (left + text_width + 2 * margin, text_bottom)], fill=color)
        draw.text((left + margin, text_top),
                  ch,
                  fill='black',
                  font=label_font)
        # Advance by the character width only, so neighbouring backgrounds
        # overlap by their margins and leave no gaps.
        left += text_width
def draw_masks(image, masks, colors, thresh: float = 0.7, alpha: float = 0.5):
    """
    Overlay object masks on an image as semi-transparent color patches.

    Args:
        image: source image; converted with ``np.array`` (``draw_objs``
            passes a PIL image).
        masks: per-object mask scores; values greater than ``thresh`` count
            as "inside the object". Presumably one mask per object with the
            image's height/width -- confirm against the caller.
        colors: one color per mask, written into the masked pixels.
        thresh: score threshold used to binarize ``masks``.
        alpha: weight of the colored overlay in the blend; the source image
            gets ``1 - alpha``.

    Returns:
        A new PIL image built from the blended array (cast to ``uint8``).
    """
    base = np.array(image)
    binary_masks = np.where(masks > thresh, True, False)

    # Paint every mask with its color on a copy, so the untouched pixels of
    # the original are still available for the blend below.
    overlay = base.copy()
    for region, rgb in zip(binary_masks, colors):
        overlay[region] = rgb

    blended = (1 - alpha) * base + alpha * overlay
    return fromarray(blended.astype(np.uint8))
def draw_objs(image: Image,
              boxes: np.ndarray = None,
              classes: np.ndarray = None,
              scores: np.ndarray = None,
              masks: np.ndarray = None,
              category_index: dict = None,
              box_thresh: float = 0.1,
              mask_thresh: float = 0.5,
              line_thickness: int = 8,
              font: str = 'arial.ttf',
              font_size: int = 24,
              draw_boxes_on_image: bool = True,
              draw_masks_on_image: bool = True):
    """
    Draw bounding boxes, class labels and masks of detected objects on an image.

    Args:
        image: image to draw on. Boxes and labels are drawn through
            ``ImageDraw.Draw(image)``, i.e. directly onto this object.
        boxes: bounding boxes, one ``(left, top, right, bottom)`` row per object.
        classes: class id of every object.
        scores: confidence of every object.
        masks: optional masks, one per object; filtered together with the boxes.
        category_index: mapping from class id to class name, used for labels.
        box_thresh: objects whose score is not greater than this are dropped.
        mask_thresh: threshold handed to ``draw_masks`` to binarize the masks.
        line_thickness: width of the box outline.
        font: font used for the labels.
        font_size: font size used for the labels.
        draw_boxes_on_image: whether boxes and labels are drawn.
        draw_masks_on_image: whether masks are drawn (only if ``masks`` is given).

    Returns:
        The image with the drawings. It is ``image`` itself unless masks were
        drawn, in which case it is the new image returned by ``draw_masks``.
        When ``boxes``, ``classes`` or ``scores`` is missing, or no object
        passes ``box_thresh``, ``image`` is returned unchanged.
    """
    # The detection arrays default to None; without them there is nothing to
    # draw, so hand the image back instead of failing inside numpy.
    if boxes is None or classes is None or scores is None:
        return image

    # Drop low-confidence objects.
    keep = np.greater(scores, box_thresh)
    boxes = boxes[keep]
    classes = classes[keep]
    scores = scores[keep]
    if masks is not None:
        masks = masks[keep]
    if len(boxes) == 0:
        return image

    # One color per object, chosen by class id. int() keeps the lookup valid
    # when the class array has a float dtype.
    colors = [ImageColor.getrgb(STANDARD_COLORS[int(cls) % len(STANDARD_COLORS)])
              for cls in classes]

    if draw_boxes_on_image:
        # Draw all boxes onto image.
        draw = ImageDraw.Draw(image)
        for box, cls, score, color in zip(boxes, classes, scores, colors):
            left, top, right, bottom = box
            # Box outline.
            draw.line([(left, top), (left, bottom), (right, bottom),
                       (right, top), (left, top)], width=line_thickness, fill=color)
            # Class name and confidence.
            draw_text(draw, box.tolist(), int(cls), float(score), category_index, color, font, font_size)

    if draw_masks_on_image and (masks is not None):
        # Draw all mask onto image.
        image = draw_masks(image, masks, colors, mask_thresh)

    return image