
Commit 4d0b049

update rise algo
1 parent 784e039 commit 4d0b049

8 files changed: +155 -245 lines changed

src/datumaro/cli/commands/explain.py

+3 -7

@@ -98,7 +98,7 @@ def build_parser(parser_ctor=argparse.ArgumentParser):
     rise_parser.add_argument(
         "-s",
         "--max-samples",
-        default=None,
+        default=100,
         type=int,
         help="Number of algorithm iterations (default: mask size ^ 2)",
     )
@@ -203,13 +203,9 @@ def explain_command(args):
 
     rise = RISE(
         model,
-        max_samples=args.max_samples,
-        mask_width=args.mask_width,
-        mask_height=args.mask_height,
+        num_masks=args.max_samples,
+        mask_size=args.mask_width,
         prob=args.prob,
-        iou_thresh=args.iou_thresh,
-        nms_thresh=args.nms_iou_thresh,
-        det_conf_thresh=args.det_conf_thresh,
         batch_size=args.batch_size,
     )
 
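
The net effect of this hunk is a call-site rename: --max-samples now feeds num_masks and --mask-width feeds mask_size, while the mask height and the detection-specific thresholds are dropped. A minimal sketch of the equivalent Python call, assuming the RISE import path and the model/image objects, none of which are shown in this diff:

# Sketch only: the import path, model, and image below are assumptions, not project code.
from datumaro.components.algorithms.rise import RISE  # path assumed

rise = RISE(
    model,              # a launcher exposing inputs/outputs/launch()
    num_masks=100,      # previously max_samples
    mask_size=7,        # previously mask_width/mask_height
    prob=0.5,
    batch_size=1,
)
saliency = next(rise.apply(image))  # apply() is a generator yielding saliency maps
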
+91 -182

@@ -1,29 +1,20 @@
-# Copyright (C) 2019-2020 Intel Corporation
+# Copyright (C) 2019-2024 Intel Corporation
 #
 # SPDX-License-Identifier: MIT
 
 # pylint: disable=unused-variable
 
-from math import ceil
-
+import cv2
 import numpy as np
 
-from datumaro.components.annotation import AnnotationType
-from datumaro.util.annotation_util import nms
+from datumaro.components.dataset import Dataset
+from datumaro.components.dataset_base import DatasetItem
+from datumaro.components.media import Image
+from datumaro.util import take_by
 
 __all__ = ["RISE"]
 
 
-def _flatmatvec(mat):
-    return np.reshape(mat, (len(mat), -1))
-
-
-def _expand(array, axis=None):
-    if axis is None:
-        axis = len(array.shape)
-    return np.expand_dims(array, axis=axis)
-
-
 class RISE:
     """
     Implements RISE: Randomized Input Sampling for
@@ -34,186 +25,104 @@ class RISE:
     def __init__(
         self,
         model,
-        max_samples=None,
-        mask_width=7,
-        mask_height=7,
-        prob=0.5,
-        iou_thresh=0.9,
-        nms_thresh=0.0,
-        det_conf_thresh=0.0,
-        batch_size=1,
+        num_masks: int = 100,
+        mask_size: int = 7,
+        prob: float = 0.5,
+        batch_size: int = 1,
     ):
+        assert prob >= 0 and prob <= 1
         self.model = model
-        self.max_samples = max_samples
-        self.mask_height = mask_height
-        self.mask_width = mask_width
+        self.num_masks = num_masks
+        self.mask_size = mask_size
         self.prob = prob
-        self.iou_thresh = iou_thresh
-        self.nms_thresh = nms_thresh
-        self.det_conf_thresh = det_conf_thresh
         self.batch_size = batch_size
 
-    @staticmethod
-    def split_outputs(annotations):
-        labels = []
-        bboxes = []
-        for r in annotations:
-            if r.type is AnnotationType.label:
-                labels.append(r)
-            elif r.type is AnnotationType.bbox:
-                bboxes.append(r)
-        return labels, bboxes
-
-    def normalize_hmaps(self, heatmaps, counts):
-        eps = np.finfo(heatmaps.dtype).eps
-        mhmaps = _flatmatvec(heatmaps)
-        mhmaps /= _expand(counts * self.prob + eps)
-        mhmaps -= _expand(np.min(mhmaps, axis=1))
-        mhmaps /= _expand(np.max(mhmaps, axis=1) + eps)
-        return np.reshape(mhmaps, heatmaps.shape)
+    def normalize_saliency(self, saliency):
+        normalized_saliency = np.empty_like(saliency)
+        for idx, sal in enumerate(saliency):
+            normalized_saliency[idx, ...] = (sal - np.min(sal)) / (np.max(sal) - np.min(sal))
+        return normalized_saliency
 
-    def apply(self, image, progressive=False):
-        import cv2
+    def generate_masks(self, image_size):
+        cell_size = np.ceil(np.array(image_size) / self.mask_size).astype(np.int8)
+        up_size = tuple([(self.mask_size + 1) * cs for cs in cell_size])
+
+        grid = np.random.rand(self.num_masks, self.mask_size, self.mask_size) < self.prob
+        grid = grid.astype("float32")
+
+        masks = np.empty((self.num_masks, *image_size))
+        for i in range(self.num_masks):
+            # Random shifts
+            x = np.random.randint(0, cell_size[0])
+            y = np.random.randint(0, cell_size[1])
 
+            # Linear upsampling and cropping
+            masks[i, ...] = cv2.resize(grid[i], up_size, interpolation=cv2.INTER_LINEAR)[
+                x : x + image_size[0], y : y + image_size[1]
+            ]
+
+        return masks
+
+    def generate_masked_dataset(self, image, image_size, masks):
+        input_image = cv2.resize(image, image_size, interpolation=cv2.INTER_LINEAR)
+
+        items = []
+        for id, mask in enumerate(masks):
+            masked_image = np.expand_dims(mask, axis=-1) * input_image
+            items.append(
+                DatasetItem(
+                    id=id,
+                    media=Image.from_numpy(masked_image),
+                )
+            )
+        return Dataset.from_iterable(items)
+
+    def apply(self, image, progressive=False):
         assert len(image.shape) in [2, 3], "Expected an input image in (H, W, C) format"
         if len(image.shape) == 3:
             assert image.shape[2] in [3, 4], "Expected BGR or BGRA input"
         image = image[:, :, :3].astype(np.float32)
 
         model = self.model
-        iou_thresh = self.iou_thresh
-
-        image_size = np.array((image.shape[:2]))
-        mask_size = np.array((self.mask_height, self.mask_width))
-        cell_size = np.ceil(image_size / mask_size)
-        upsampled_size = np.ceil((mask_size + 1) * cell_size)
-
-        rng = lambda shape=None: np.random.rand(*shape)
-        samples = np.prod(image_size)
-        if self.max_samples is not None:
-            samples = min(self.max_samples, samples)
-        batch_size = self.batch_size
-
-        # model is expected to get NxCxHxW shaped input tensor
-        pred = next(iter(model.infer(_expand(np.transpose(image, (2, 0, 1)), 0))))
-        result = model.postprocess(pred, None)
-        result_labels, result_bboxes = self.split_outputs(result)
-        if 0 < self.det_conf_thresh:
-            result_bboxes = [
-                b for b in result_bboxes if self.det_conf_thresh <= b.attributes["score"]
-            ]
-        if 0 < self.nms_thresh:
-            result_bboxes = nms(result_bboxes, self.nms_thresh)
-
-        predicted_labels = set()
-        if len(result_labels) != 0:
-            predicted_label = max(result_labels, key=lambda r: r.attributes["score"]).label
-            predicted_labels.add(predicted_label)
-        if len(result_bboxes) != 0:
-            for bbox in result_bboxes:
-                predicted_labels.add(bbox.label)
-        predicted_labels = {label: idx for idx, label in enumerate(predicted_labels)}
-
-        predicted_bboxes = result_bboxes
-
-        heatmaps_count = len(predicted_labels) + len(predicted_bboxes)
-        heatmaps = np.zeros((heatmaps_count, *image_size), dtype=np.float32)
-        total_counts = np.zeros(heatmaps_count, dtype=np.int32)
-        confs = np.zeros(heatmaps_count, dtype=np.float32)
-
-        heatmap_id = 0
-
-        # label_heatmaps = None
-        label_total_counts = None
-        label_confs = None
-        if len(predicted_labels) != 0:
-            step = len(predicted_labels)
-            # label_heatmaps = heatmaps[heatmap_id : heatmap_id + step]
-            label_total_counts = total_counts[heatmap_id : heatmap_id + step]
-            label_confs = confs[heatmap_id : heatmap_id + step]
-            heatmap_id += step
-
-        # bbox_heatmaps = None
-        bbox_total_counts = None
-        bbox_confs = None
-        if len(predicted_bboxes) != 0:
-            step = len(predicted_bboxes)
-            # bbox_heatmaps = heatmaps[heatmap_id : heatmap_id + step]
-            bbox_total_counts = total_counts[heatmap_id : heatmap_id + step]
-            bbox_confs = confs[heatmap_id : heatmap_id + step]
-            heatmap_id += step
-
-        ups_mask = np.empty(upsampled_size.astype(int), dtype=np.float32)
-        masks = np.empty((batch_size, *image_size), dtype=np.float32)
-
-        full_batch_inputs = np.empty((batch_size, *image.shape), dtype=np.float32)
-        current_heatmaps = np.empty_like(heatmaps)
-        for b in range(ceil(samples / batch_size)):
-            batch_pos = b * batch_size
-            current_batch_size = min(samples - batch_pos, batch_size)
-
-            batch_masks = masks[:current_batch_size]
-            for i in range(current_batch_size):
-                mask = (rng(mask_size) < self.prob).astype(np.float32)
-                cv2.resize(mask, (int(upsampled_size[1]), int(upsampled_size[0])), ups_mask)
-
-                offsets = np.round(rng((2,)) * cell_size)
-                mask = ups_mask[
-                    int(offsets[0]) : int(image_size[0] + offsets[0]),
-                    int(offsets[1]) : int(image_size[1] + offsets[1]),
-                ]
-                batch_masks[i] = mask
-
-            batch_inputs = full_batch_inputs[:current_batch_size]
-            np.multiply(_expand(batch_masks), _expand(image, 0), out=batch_inputs)
-
-            preds = model.infer(np.transpose(batch_inputs, (0, 3, 1, 2)))
-            results = [model.postprocess(pred, None) for pred in preds]
-            for mask, result in zip(batch_masks, results):
-                result_labels, result_bboxes = self.split_outputs(result)
-
-                confs.fill(0)
-                if len(predicted_labels) != 0:
-                    for r in result_labels:
-                        idx = predicted_labels.get(r.label, None)
-                        if idx is not None:
-                            label_total_counts[idx] += 1
-                            label_confs[idx] += r.attributes["score"]
-                    for r in result_bboxes:
-                        idx = predicted_labels.get(r.label, None)
-                        if idx is not None:
-                            label_total_counts[idx] += 1
-                            label_confs[idx] += r.attributes["score"]
-
-                if len(predicted_bboxes) != 0 and len(result_bboxes) != 0:
-                    if 0 < self.det_conf_thresh:
-                        result_bboxes = [
-                            b
-                            for b in result_bboxes
-                            if self.det_conf_thresh <= b.attributes["score"]
-                        ]
-                    if 0 < self.nms_thresh:
-                        result_bboxes = nms(result_bboxes, self.nms_thresh)
-
-                    for detection in result_bboxes:
-                        for pred_idx, pred in enumerate(predicted_bboxes):
-                            if pred.label != detection.label:
-                                continue
-
-                            iou = pred.iou(detection)
-                            assert iou == -1 or 0 <= iou and iou <= 1
-                            if iou < iou_thresh:
-                                continue
-
-                            bbox_total_counts[pred_idx] += 1
-
-                            conf = detection.attributes["score"]
-                            bbox_confs[pred_idx] += conf
-
-                np.multiply.outer(confs, mask, out=current_heatmaps)
-                heatmaps += current_heatmaps
+
+        image_size = model.inputs[0].shape
+        logit_size = model.outputs[0].shape
+
+        batch_size = image_size[0]
+        if image_size[1] in [1, 3]:  # for CxHxW
+            image_size = (image_size[2], image_size[3])
+        elif image_size[3] in [1, 3]:  # for HxWxC
+            image_size = (image_size[1], image_size[2])
+
+        masks = self.generate_masks(image_size=image_size)
+        masked_dataset = self.generate_masked_dataset(image, image_size, masks)
+
+        saliency = np.zeros((logit_size[1], *image_size), dtype=np.float32)
+        for batch_id, batch in enumerate(take_by(masked_dataset, batch_size)):
+            outputs = model.launch(batch)
+
+            for sample_id in range(len(batch)):
+                mask = masks[batch_size * batch_id + sample_id]
+                for class_idx in range(logit_size[1]):
+                    score = outputs[sample_id][class_idx].attributes["score"]
+                    saliency[class_idx, ...] += score * mask
+
+                # [TODO] wonjuleee: support DRISE for detection model explainability
+                # if isinstance(self.target, Label):
+                #     logits = outputs[sample_id][0].vector
+                #     max_score = logits[self.target.label]
+                # elif isinstance(self.target, Bbox):
+                #     preds = outputs[sample_id][0]
+                #     max_score = 0
+                #     for box in preds:
+                #         if box[0] == self.target.label:
+                #             confidence, box = box[1], box[2]
+                #             score = iou(self.target.get_bbox, box) * confidence
+                #             if score > max_score:
+                #                 max_score = score
+                #     saliency += max_score * mask
 
         if progressive:
-            yield self.normalize_hmaps(heatmaps.copy(), total_counts)
+            yield self.normalize_saliency(saliency)
 
-        yield self.normalize_hmaps(heatmaps, total_counts)
+        yield self.normalize_saliency(saliency)
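
Stripped of the Dataset/Image plumbing, the rewritten class implements classic RISE: sample low-resolution binary grids, upsample and randomly shift them into smooth occlusion masks, score the occluded images, and accumulate score-weighted masks per class. A self-contained sketch of that computation, assuming only NumPy, OpenCV, and a user-supplied score_fn that maps an HxWx3 image to a vector of per-class confidences (the function below is illustrative, not the library API):

import cv2
import numpy as np

def rise_saliency(image, score_fn, num_classes, num_masks=100, mask_size=7, prob=0.5):
    # image: HxWx3 float array; score_fn(img) -> (num_classes,) confidence vector
    h, w = image.shape[:2]
    cell_h, cell_w = int(np.ceil(h / mask_size)), int(np.ceil(w / mask_size))
    up_h, up_w = (mask_size + 1) * cell_h, (mask_size + 1) * cell_w

    # Low-resolution Bernoulli grids, one per mask (cells kept with probability prob)
    grid = (np.random.rand(num_masks, mask_size, mask_size) < prob).astype(np.float32)

    saliency = np.zeros((num_classes, h, w), dtype=np.float32)
    for i in range(num_masks):
        # Bilinear upsampling plus a random crop yields a smooth occlusion mask
        x, y = np.random.randint(0, cell_h), np.random.randint(0, cell_w)
        up = cv2.resize(grid[i], (up_w, up_h), interpolation=cv2.INTER_LINEAR)
        mask = up[x : x + h, y : y + w]

        scores = np.asarray(score_fn(mask[..., None] * image))  # score the occluded input
        saliency += scores[:, None, None] * mask  # weight the mask by each class score

    # Per-class min-max normalization, mirroring normalize_saliency above
    mins = saliency.min(axis=(1, 2), keepdims=True)
    maxs = saliency.max(axis=(1, 2), keepdims=True)
    return (saliency - mins) / (maxs - mins + 1e-12)

The small epsilon in the last line is a defensive choice of this sketch; the committed normalize_saliency divides by the raw max-min range.
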

src/datumaro/components/shift_analyzer.py

+7 -6

@@ -4,7 +4,6 @@
 
 # ruff: noqa: E501
 
-import itertools
 from collections import defaultdict
 from typing import TYPE_CHECKING, Dict, List, Optional
 
@@ -79,8 +78,9 @@ def get_activation_stats(self, dataset: IDataset) -> RunningStats1D:
         running_stats = RunningStats1D()
 
         for batch in take_by(dataset, self._batch_size):
-            features = self.model.launch(batch)
-            running_stats.add(list(itertools.chain(*features)))
+            outputs = self.model.launch(batch)[0]
+            features = [outputs[-1]]  # extracted feature vector of googlenet-v4
+            running_stats.add(features)
 
         return running_stats
 
@@ -99,10 +99,11 @@ def get_activation_stats(self, dataset: IDataset) -> Dict[int, RunningStats1D]:
                     inputs.append(np.atleast_3d(item.media.data))
                     targets.append(ann.label)
 
-            features = self.model.launch(batch)
+            outputs = self.model.launch(batch)[0]
+            features = [outputs[-1]]  # extracted feature vector of googlenet-v4
 
-            for feat, target in zip(features, targets):
-                running_stats[target].add(feat)
+            for target in targets:
+                running_stats[target].add(features)
 
         return running_stats
 
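
Both hunks in this file encode the same contract with the launcher, restated here as a sketch (variable names are illustrative; googlenet-v4 is the feature extractor named in the inline comment): launch() returns per-sample annotation lists, and the last annotation of the first sample is treated as the embedding fed to RunningStats1D.

outputs = model.launch(batch)[0]   # annotations of the first sample in the batch
features = [outputs[-1]]           # last entry assumed to hold the feature vector
running_stats.add(features)        # accumulate into the running statistics
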
src/datumaro/plugins/openvino_plugin/launcher.py

+8

@@ -208,6 +208,14 @@ def __init__(
         self._check_model_support(self._network, self._device)
         self._load_executable_net()
 
+    @property
+    def inputs(self):
+        return self._network.inputs
+
+    @property
+    def outputs(self):
+        return self._network.outputs
+
     def _check_model_support(self, net, device):
         not_supported_layers = set(
             name for name, dev in self._core.query_model(net, device).items() if not dev
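
These two properties are what the rewritten RISE.apply reads to derive the expected input resolution and logit size. A brief sketch, assuming a launcher instance with a static NCHW input shape (the variable names are placeholders):

n, c, h, w = launcher.inputs[0].shape        # batch, channels, height, width
num_classes = launcher.outputs[0].shape[1]   # logit dimension used for the saliency maps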
