1
- # Copyright (C) 2019-2020 Intel Corporation
1
+ # Copyright (C) 2019-2024 Intel Corporation
2
2
#
3
3
# SPDX-License-Identifier: MIT
4
4
5
5
# pylint: disable=unused-variable
6
6
7
- from math import ceil
8
-
7
+ import cv2
9
8
import numpy as np
10
9
11
- from datumaro .components .annotation import AnnotationType
12
- from datumaro .util .annotation_util import nms
10
+ from datumaro .components .dataset import Dataset
11
+ from datumaro .components .dataset_base import DatasetItem
12
+ from datumaro .components .media import Image
13
+ from datumaro .util import take_by
13
14
14
15
__all__ = ["RISE" ]
15
16
16
17
17
- def _flatmatvec (mat ):
18
- return np .reshape (mat , (len (mat ), - 1 ))
19
-
20
-
21
- def _expand (array , axis = None ):
22
- if axis is None :
23
- axis = len (array .shape )
24
- return np .expand_dims (array , axis = axis )
25
-
26
-
27
18
class RISE :
28
19
"""
29
20
Implements RISE: Randomized Input Sampling for
@@ -34,186 +25,104 @@ class RISE:
34
25
def __init__ (
35
26
self ,
36
27
model ,
37
- max_samples = None ,
38
- mask_width = 7 ,
39
- mask_height = 7 ,
40
- prob = 0.5 ,
41
- iou_thresh = 0.9 ,
42
- nms_thresh = 0.0 ,
43
- det_conf_thresh = 0.0 ,
44
- batch_size = 1 ,
28
+ num_masks : int = 100 ,
29
+ mask_size : int = 7 ,
30
+ prob : float = 0.5 ,
31
+ batch_size : int = 1 ,
45
32
):
33
+ assert prob >= 0 and prob <= 1
46
34
self .model = model
47
- self .max_samples = max_samples
48
- self .mask_height = mask_height
49
- self .mask_width = mask_width
35
+ self .num_masks = num_masks
36
+ self .mask_size = mask_size
50
37
self .prob = prob
51
- self .iou_thresh = iou_thresh
52
- self .nms_thresh = nms_thresh
53
- self .det_conf_thresh = det_conf_thresh
54
38
self .batch_size = batch_size
55
39
56
- @staticmethod
57
- def split_outputs (annotations ):
58
- labels = []
59
- bboxes = []
60
- for r in annotations :
61
- if r .type is AnnotationType .label :
62
- labels .append (r )
63
- elif r .type is AnnotationType .bbox :
64
- bboxes .append (r )
65
- return labels , bboxes
66
-
67
- def normalize_hmaps (self , heatmaps , counts ):
68
- eps = np .finfo (heatmaps .dtype ).eps
69
- mhmaps = _flatmatvec (heatmaps )
70
- mhmaps /= _expand (counts * self .prob + eps )
71
- mhmaps -= _expand (np .min (mhmaps , axis = 1 ))
72
- mhmaps /= _expand (np .max (mhmaps , axis = 1 ) + eps )
73
- return np .reshape (mhmaps , heatmaps .shape )
40
+ def normalize_saliency (self , saliency ):
41
+ normalized_saliency = np .empty_like (saliency )
42
+ for idx , sal in enumerate (saliency ):
43
+ normalized_saliency [idx , ...] = (sal - np .min (sal )) / (np .max (sal ) - np .min (sal ))
44
+ return normalized_saliency
74
45
75
- def apply (self , image , progressive = False ):
76
- import cv2
46
+ def generate_masks (self , image_size ):
47
+ cell_size = np .ceil (np .array (image_size ) / self .mask_size ).astype (np .int8 )
48
+ up_size = tuple ([(self .mask_size + 1 ) * cs for cs in cell_size ])
49
+
50
+ grid = np .random .rand (self .num_masks , self .mask_size , self .mask_size ) < self .prob
51
+ grid = grid .astype ("float32" )
52
+
53
+ masks = np .empty ((self .num_masks , * image_size ))
54
+ for i in range (self .num_masks ):
55
+ # Random shifts
56
+ x = np .random .randint (0 , cell_size [0 ])
57
+ y = np .random .randint (0 , cell_size [1 ])
77
58
59
+ # Linear upsampling and cropping
60
+ masks [i , ...] = cv2 .resize (grid [i ], up_size , interpolation = cv2 .INTER_LINEAR )[
61
+ x : x + image_size [0 ], y : y + image_size [1 ]
62
+ ]
63
+
64
+ return masks
65
+
66
+ def generate_masked_dataset (self , image , image_size , masks ):
67
+ input_image = cv2 .resize (image , image_size , interpolation = cv2 .INTER_LINEAR )
68
+
69
+ items = []
70
+ for id , mask in enumerate (masks ):
71
+ masked_image = np .expand_dims (mask , axis = - 1 ) * input_image
72
+ items .append (
73
+ DatasetItem (
74
+ id = id ,
75
+ media = Image .from_numpy (masked_image ),
76
+ )
77
+ )
78
+ return Dataset .from_iterable (items )
79
+
80
+ def apply (self , image , progressive = False ):
78
81
assert len (image .shape ) in [2 , 3 ], "Expected an input image in (H, W, C) format"
79
82
if len (image .shape ) == 3 :
80
83
assert image .shape [2 ] in [3 , 4 ], "Expected BGR or BGRA input"
81
84
image = image [:, :, :3 ].astype (np .float32 )
82
85
83
86
model = self .model
84
- iou_thresh = self .iou_thresh
85
-
86
- image_size = np .array ((image .shape [:2 ]))
87
- mask_size = np .array ((self .mask_height , self .mask_width ))
88
- cell_size = np .ceil (image_size / mask_size )
89
- upsampled_size = np .ceil ((mask_size + 1 ) * cell_size )
90
-
91
- rng = lambda shape = None : np .random .rand (* shape )
92
- samples = np .prod (image_size )
93
- if self .max_samples is not None :
94
- samples = min (self .max_samples , samples )
95
- batch_size = self .batch_size
96
-
97
- # model is expected to get NxCxHxW shaped input tensor
98
- pred = next (iter (model .infer (_expand (np .transpose (image , (2 , 0 , 1 )), 0 ))))
99
- result = model .postprocess (pred , None )
100
- result_labels , result_bboxes = self .split_outputs (result )
101
- if 0 < self .det_conf_thresh :
102
- result_bboxes = [
103
- b for b in result_bboxes if self .det_conf_thresh <= b .attributes ["score" ]
104
- ]
105
- if 0 < self .nms_thresh :
106
- result_bboxes = nms (result_bboxes , self .nms_thresh )
107
-
108
- predicted_labels = set ()
109
- if len (result_labels ) != 0 :
110
- predicted_label = max (result_labels , key = lambda r : r .attributes ["score" ]).label
111
- predicted_labels .add (predicted_label )
112
- if len (result_bboxes ) != 0 :
113
- for bbox in result_bboxes :
114
- predicted_labels .add (bbox .label )
115
- predicted_labels = {label : idx for idx , label in enumerate (predicted_labels )}
116
-
117
- predicted_bboxes = result_bboxes
118
-
119
- heatmaps_count = len (predicted_labels ) + len (predicted_bboxes )
120
- heatmaps = np .zeros ((heatmaps_count , * image_size ), dtype = np .float32 )
121
- total_counts = np .zeros (heatmaps_count , dtype = np .int32 )
122
- confs = np .zeros (heatmaps_count , dtype = np .float32 )
123
-
124
- heatmap_id = 0
125
-
126
- # label_heatmaps = None
127
- label_total_counts = None
128
- label_confs = None
129
- if len (predicted_labels ) != 0 :
130
- step = len (predicted_labels )
131
- # label_heatmaps = heatmaps[heatmap_id : heatmap_id + step]
132
- label_total_counts = total_counts [heatmap_id : heatmap_id + step ]
133
- label_confs = confs [heatmap_id : heatmap_id + step ]
134
- heatmap_id += step
135
-
136
- # bbox_heatmaps = None
137
- bbox_total_counts = None
138
- bbox_confs = None
139
- if len (predicted_bboxes ) != 0 :
140
- step = len (predicted_bboxes )
141
- # bbox_heatmaps = heatmaps[heatmap_id : heatmap_id + step]
142
- bbox_total_counts = total_counts [heatmap_id : heatmap_id + step ]
143
- bbox_confs = confs [heatmap_id : heatmap_id + step ]
144
- heatmap_id += step
145
-
146
- ups_mask = np .empty (upsampled_size .astype (int ), dtype = np .float32 )
147
- masks = np .empty ((batch_size , * image_size ), dtype = np .float32 )
148
-
149
- full_batch_inputs = np .empty ((batch_size , * image .shape ), dtype = np .float32 )
150
- current_heatmaps = np .empty_like (heatmaps )
151
- for b in range (ceil (samples / batch_size )):
152
- batch_pos = b * batch_size
153
- current_batch_size = min (samples - batch_pos , batch_size )
154
-
155
- batch_masks = masks [:current_batch_size ]
156
- for i in range (current_batch_size ):
157
- mask = (rng (mask_size ) < self .prob ).astype (np .float32 )
158
- cv2 .resize (mask , (int (upsampled_size [1 ]), int (upsampled_size [0 ])), ups_mask )
159
-
160
- offsets = np .round (rng ((2 ,)) * cell_size )
161
- mask = ups_mask [
162
- int (offsets [0 ]) : int (image_size [0 ] + offsets [0 ]),
163
- int (offsets [1 ]) : int (image_size [1 ] + offsets [1 ]),
164
- ]
165
- batch_masks [i ] = mask
166
-
167
- batch_inputs = full_batch_inputs [:current_batch_size ]
168
- np .multiply (_expand (batch_masks ), _expand (image , 0 ), out = batch_inputs )
169
-
170
- preds = model .infer (np .transpose (batch_inputs , (0 , 3 , 1 , 2 )))
171
- results = [model .postprocess (pred , None ) for pred in preds ]
172
- for mask , result in zip (batch_masks , results ):
173
- result_labels , result_bboxes = self .split_outputs (result )
174
-
175
- confs .fill (0 )
176
- if len (predicted_labels ) != 0 :
177
- for r in result_labels :
178
- idx = predicted_labels .get (r .label , None )
179
- if idx is not None :
180
- label_total_counts [idx ] += 1
181
- label_confs [idx ] += r .attributes ["score" ]
182
- for r in result_bboxes :
183
- idx = predicted_labels .get (r .label , None )
184
- if idx is not None :
185
- label_total_counts [idx ] += 1
186
- label_confs [idx ] += r .attributes ["score" ]
187
-
188
- if len (predicted_bboxes ) != 0 and len (result_bboxes ) != 0 :
189
- if 0 < self .det_conf_thresh :
190
- result_bboxes = [
191
- b
192
- for b in result_bboxes
193
- if self .det_conf_thresh <= b .attributes ["score" ]
194
- ]
195
- if 0 < self .nms_thresh :
196
- result_bboxes = nms (result_bboxes , self .nms_thresh )
197
-
198
- for detection in result_bboxes :
199
- for pred_idx , pred in enumerate (predicted_bboxes ):
200
- if pred .label != detection .label :
201
- continue
202
-
203
- iou = pred .iou (detection )
204
- assert iou == - 1 or 0 <= iou and iou <= 1
205
- if iou < iou_thresh :
206
- continue
207
-
208
- bbox_total_counts [pred_idx ] += 1
209
-
210
- conf = detection .attributes ["score" ]
211
- bbox_confs [pred_idx ] += conf
212
-
213
- np .multiply .outer (confs , mask , out = current_heatmaps )
214
- heatmaps += current_heatmaps
87
+
88
+ image_size = model .inputs [0 ].shape
89
+ logit_size = model .outputs [0 ].shape
90
+
91
+ batch_size = image_size [0 ]
92
+ if image_size [1 ] in [1 , 3 ]: # for CxHxW
93
+ image_size = (image_size [2 ], image_size [3 ])
94
+ elif image_size [3 ] in [1 , 3 ]: # for HxWxC
95
+ image_size = (image_size [1 ], image_size [2 ])
96
+
97
+ masks = self .generate_masks (image_size = image_size )
98
+ masked_dataset = self .generate_masked_dataset (image , image_size , masks )
99
+
100
+ saliency = np .zeros ((logit_size [1 ], * image_size ), dtype = np .float32 )
101
+ for batch_id , batch in enumerate (take_by (masked_dataset , batch_size )):
102
+ outputs = model .launch (batch )
103
+
104
+ for sample_id in range (len (batch )):
105
+ mask = masks [batch_size * batch_id + sample_id ]
106
+ for class_idx in range (logit_size [1 ]):
107
+ score = outputs [sample_id ][class_idx ].attributes ["score" ]
108
+ saliency [class_idx , ...] += score * mask
109
+
110
+ # [TODO] wonjuleee: support DRISE for detection model explainability
111
+ # if isinstance(self.target, Label):
112
+ # logits = outputs[sample_id][0].vector
113
+ # max_score = logits[self.target.label]
114
+ # elif isinstance(self.target, Bbox):
115
+ # preds = outputs[sample_id][0]
116
+ # max_score = 0
117
+ # for box in preds:
118
+ # if box[0] == self.target.label:
119
+ # confidence, box = box[1], box[2]
120
+ # score = iou(self.target.get_bbox, box) * confidence
121
+ # if score > max_score:
122
+ # max_score = score
123
+ # saliency += max_score * mask
215
124
216
125
if progressive :
217
- yield self .normalize_hmaps ( heatmaps . copy (), total_counts )
126
+ yield self .normalize_saliency ( saliency )
218
127
219
- yield self .normalize_hmaps ( heatmaps , total_counts )
128
+ yield self .normalize_saliency ( saliency )
0 commit comments