 # Copyright (c) OpenMMLab. All rights reserved.
 from typing import Dict, List, Optional

+import cv2
 import mmcv
 import numpy as np
+import torch
 from mmengine.dist import master_only
 from mmengine.structures import PixelData
 from mmengine.visualization import Visualizer
@@ -42,8 +44,8 @@ class SegLocalVisualizer(Visualizer):
         >>> import numpy as np
         >>> import torch
         >>> from mmengine.structures import PixelData
-        >>> from mmseg.data import SegDataSample
-        >>> from mmseg.engine.visualization import SegLocalVisualizer
+        >>> from mmseg.structures import SegDataSample
+        >>> from mmseg.visualization import SegLocalVisualizer

         >>> seg_local_visualizer = SegLocalVisualizer()
         >>> image = np.random.randint(0, 256,
@@ -60,7 +62,7 @@ class SegLocalVisualizer(Visualizer):
         >>> seg_local_visualizer.add_datasample(
         ...                     'visualizer_example', image,
         ...                     gt_seg_data_sample, show=True)
-    """ # noqa
+    """ # noqa

     def __init__(self,
                  name: str = 'visualizer',
@@ -76,9 +78,32 @@ def __init__(self,
         self.alpha: float = alpha
         self.set_dataset_meta(palette, classes, dataset_name)

-    def _draw_sem_seg(self, image: np.ndarray, sem_seg: PixelData,
+    def _get_center_loc(self, mask: np.ndarray) -> np.ndarray:
+        """Get semantic seg center coordinate.
+
+        Args:
+            mask (np.ndarray): Single-class mask taken from ``sem_seg``.
+        """
+        loc = np.argwhere(mask == 1)
+
+        loc_sort = np.array(
+            sorted(loc.tolist(), key=lambda row: (row[0], row[1])))
+        y_list = loc_sort[:, 0]
+        unique, indices, counts = np.unique(
+            y_list, return_index=True, return_counts=True)
+        y_loc = unique[counts.argmax()]
+        y_most_freq_loc = loc[loc_sort[:, 0] == y_loc]
+        center_num = len(y_most_freq_loc) // 2
+        x = y_most_freq_loc[center_num][1]
+        y = y_most_freq_loc[center_num][0]
+        return np.array([x, y])
+
+    def _draw_sem_seg(self,
+                      image: np.ndarray,
+                      sem_seg: PixelData,
                       classes: Optional[List],
-                      palette: Optional[List]) -> np.ndarray:
+                      palette: Optional[List],
+                      withLabels: Optional[bool] = True) -> np.ndarray:
         """Draw semantic seg of GT or prediction.

         Args:
@@ -94,6 +119,8 @@ def _draw_sem_seg(self, image: np.ndarray, sem_seg: PixelData,
             palette (list, optional): Input palette for result rendering, which
                 is a list of color palettes corresponding to the classes.
                 Defaults to None.
+            withLabels (bool, optional): Add semantic labels in the
+                visualization result. Defaults to True.

         Returns:
             np.ndarray: the drawn image which channel is RGB.
@@ -112,6 +139,43 @@ def _draw_sem_seg(self, image: np.ndarray, sem_seg: PixelData,
         for label, color in zip(labels, colors):
             mask[sem_seg[0] == label, :] = color

+        if withLabels:
+            font = cv2.FONT_HERSHEY_SIMPLEX
+            # (0,1] to change the size of the text relative to the image
+            scale = 0.05
+            fontScale = min(image.shape[0], image.shape[1]) / (25 / scale)
+            fontColor = (255, 255, 255)
+            if image.shape[0] < 300 or image.shape[1] < 300:
+                thickness = 1
+                rectangleThickness = 1
+            else:
+                thickness = 2
+                rectangleThickness = 2
+            lineType = 2
+
+            if isinstance(sem_seg[0], torch.Tensor):
+                masks = sem_seg[0].numpy() == labels[:, None, None]
+            else:
+                masks = sem_seg[0] == labels[:, None, None]
+            masks = masks.astype(np.uint8)
+            for mask_num in range(len(labels)):
+                classes_id = labels[mask_num]
+                classes_color = colors[mask_num]
+                loc = self._get_center_loc(masks[mask_num])
+                text = classes[classes_id]
+                (label_width, label_height), baseline = cv2.getTextSize(
+                    text, font, fontScale, thickness)
+                mask = cv2.rectangle(mask, loc,
+                                     (loc[0] + label_width + baseline,
+                                      loc[1] + label_height + baseline),
+                                     classes_color, -1)
+                mask = cv2.rectangle(mask, loc,
+                                     (loc[0] + label_width + baseline,
+                                      loc[1] + label_height + baseline),
+                                     (0, 0, 0), rectangleThickness)
+                mask = cv2.putText(mask, text, (loc[0], loc[1] + label_height),
+                                   font, fontScale, fontColor, thickness,
+                                   lineType)
         color_seg = (image * (1 - self.alpha) + mask * self.alpha).astype(
             np.uint8)
         self.set_image(color_seg)
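For reference, the new `_get_center_loc` helper anchors each class label at the middle pixel of the mask row that contains the most foreground pixels. A quick sanity check on a toy mask (the mask values below are made up for illustration and assume mmseg with this change installed):

import numpy as np
from mmseg.visualization import SegLocalVisualizer

# 3x4 toy mask: row 1 holds three foreground pixels, row 2 holds one
toy_mask = np.array([[0, 0, 0, 0],
                     [0, 1, 1, 1],
                     [0, 0, 1, 0]], dtype=np.uint8)

visualizer = SegLocalVisualizer()
# prints [2 1]: the (x, y) of the middle pixel in the most-populated row
print(visualizer._get_center_loc(toy_mask))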
@@ -137,7 +201,7 @@ def set_dataset_meta(self,
                 visualizer will use the meta information of the dataset i.e.
                 classes and palette, but the `classes` and `palette` have
                 higher priority. Defaults to None.
-        """ # noqa
+        """ # noqa
         # Set default value. When calling
         # `SegLocalVisualizer().dataset_meta=xxx`,
         # it will override the default value.
@@ -161,7 +225,8 @@ def add_datasample(
             wait_time: float = 0,
             # TODO: Supported in mmengine's Visualizer.
             out_file: Optional[str] = None,
-            step: int = 0) -> None:
+            step: int = 0,
+            withLabels: Optional[bool] = True) -> None:
         """Draw datasample and save to all backends.

         - If GT and prediction are plotted at the same time, they are
@@ -187,6 +252,8 @@ def add_datasample(
             wait_time (float): The interval of show (s). Defaults to 0.
             out_file (str): Path to output file. Defaults to None.
             step (int): Global step value to record. Defaults to 0.
+            withLabels (bool, optional): Add semantic labels in the
+                visualization result. Defaults to True.
         """
         classes = self.dataset_meta.get('classes', None)
         palette = self.dataset_meta.get('palette', None)
@@ -202,7 +269,7 @@ def add_datasample(
                                         'segmentation results.'
             gt_img_data = self._draw_sem_seg(gt_img_data,
                                              data_sample.gt_sem_seg, classes,
-                                             palette)
+                                             palette, withLabels)

         if (draw_pred and data_sample is not None
                 and 'pred_sem_seg' in data_sample):
@@ -213,7 +280,7 @@ def add_datasample(
                                         'segmentation results.'
             pred_img_data = self._draw_sem_seg(pred_img_data,
                                                data_sample.pred_sem_seg,
-                                               classes, palette)
+                                               classes, palette, withLabels)

         if gt_img_data is not None and pred_img_data is not None:
             drawn_img = np.concatenate((gt_img_data, pred_img_data), axis=1)
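With this change, callers can toggle the per-class text overlays through the new `withLabels` argument of `add_datasample`. A minimal usage sketch adapted from the class docstring (the random image and mask are placeholders):

import numpy as np
import torch
from mmengine.structures import PixelData
from mmseg.structures import SegDataSample
from mmseg.visualization import SegLocalVisualizer

seg_local_visualizer = SegLocalVisualizer()
seg_local_visualizer.dataset_meta = dict(
    classes=('background', 'foreground'),
    palette=[[120, 120, 120], [6, 230, 230]])

image = np.random.randint(0, 256, size=(10, 12, 3)).astype('uint8')
gt_sem_seg = PixelData(data=torch.randint(0, 2, (1, 10, 12)))
gt_seg_data_sample = SegDataSample()
gt_seg_data_sample.gt_sem_seg = gt_sem_seg

# withLabels=False skips drawing the class-name boxes added by this PR
seg_local_visualizer.add_datasample(
    'visualizer_example', image, gt_seg_data_sample,
    show=False, withLabels=False)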