-
Notifications
You must be signed in to change notification settings - Fork 0
/
myclasses.py
405 lines (343 loc) · 17.8 KB
/
myclasses.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
from collections import OrderedDict
import torch
from torch.hub import load_state_dict_from_url
from torchvision.ops import MultiScaleRoIAlign
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone
from torchvision.models.detection.generalized_rcnn import GeneralizedRCNN
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor, MaskRCNNHeads, model_urls
from torchvision.models.detection.rpn import RegionProposalNetwork, concat_box_prediction_layers, AnchorGenerator, \
RPNHead
from torchvision.models.detection.roi_heads import RoIHeads, keypointrcnn_loss, fastrcnn_loss, maskrcnn_inference, \
maskrcnn_loss, keypointrcnn_inference
from torchvision.models.detection.transform import GeneralizedRCNNTransform
from torchvision.models.detection.faster_rcnn import TwoMLPHead, FastRCNNPredictor
class MyRegionProposalNetwork(RegionProposalNetwork):
    """ Add a slight change that allows me to calculate the loss in `eval()` mode.

    The only difference to the parent class is the extra `return_loss`
    argument of `forward`: the loss branch runs when the module is in
    training mode *or* when `return_loss` is True.
    """

    def forward(self, images, features, targets=None, return_loss=False):
        """
        Arguments:
            images (ImageList): images for which we want to compute the predictions
            features (Dict[Tensor]): features computed from the images that are
                used for computing the predictions. Only the values of the
                mapping are used; each tensor corresponds to a feature level
            targets (List[Dict[Tensor]]): ground-truth boxes present in the image (optional).
                If provided, each element in the dict should contain a field `boxes`,
                with the locations of the ground-truth boxes.
            return_loss (Bool): return the loss (even if we are in eval mode)

        Returns:
            boxes (List[Tensor]): the predicted boxes from the RPN, one Tensor per
                image.
            losses (Dict[Tensor]): `loss_objectness` and `loss_rpn_box_reg` when
                the module is in training mode or `return_loss` is True,
                otherwise an empty dict.

        Raises:
            ValueError: if the loss is requested (training mode or
                `return_loss=True`) but `targets` is None.
        """
        compute_loss = self.training or return_loss
        # Fail early with a clear message instead of handing `None` to
        # `assign_targets_to_anchors` further down.
        if compute_loss and targets is None:
            raise ValueError("targets should be passed when the RPN loss is requested "
                             "(training mode or return_loss=True)")

        # RPN uses all feature maps that are available
        features = list(features.values())
        objectness, pred_bbox_deltas = self.head(features)
        anchors = self.anchor_generator(images, features)

        num_images = len(anchors)
        num_anchors_per_level = [o[0].numel() for o in objectness]
        objectness, pred_bbox_deltas = \
            concat_box_prediction_layers(objectness, pred_bbox_deltas)

        # apply pred_bbox_deltas to anchors to obtain the decoded proposals
        # note that we detach the deltas because Faster R-CNN do not backprop through
        # the proposals
        proposals = self.box_coder.decode(pred_bbox_deltas.detach(), anchors)
        proposals = proposals.view(num_images, -1, 4)
        # NOTE(review): `filter_proposals` is inherited and not told about
        # `return_loss`; presumably it keeps using its eval-mode settings when
        # the loss is requested in eval mode -- confirm against the parent class.
        boxes, scores = self.filter_proposals(proposals, objectness, images.image_sizes, num_anchors_per_level)

        losses = {}
        if compute_loss:
            labels, matched_gt_boxes = self.assign_targets_to_anchors(anchors, targets)
            regression_targets = self.box_coder.encode(matched_gt_boxes, anchors)
            loss_objectness, loss_rpn_box_reg = self.compute_loss(
                objectness, pred_bbox_deltas, labels, regression_targets)
            losses = {
                "loss_objectness": loss_objectness,
                "loss_rpn_box_reg": loss_rpn_box_reg,
            }
        return boxes, losses
class MyRoIHeads(RoIHeads):
    """ Add a slight change that allows me to calculate the loss in `eval()` mode.

    The only difference to the parent class is the extra `return_loss`
    argument of `forward`: every "training" branch runs when the module is in
    training mode *or* when `return_loss` is True.
    """

    @staticmethod
    def _select_positive_samples(proposals, labels, matched_idxs):
        """Keep, per image, only the proposals (and matched indices) whose label is > 0."""
        pos_proposals = []
        pos_matched_idxs = []
        for img_id, img_proposals in enumerate(proposals):
            pos = torch.nonzero(labels[img_id] > 0).squeeze(1)
            pos_proposals.append(img_proposals[pos])
            pos_matched_idxs.append(matched_idxs[img_id][pos])
        return pos_proposals, pos_matched_idxs

    def forward(self, features, proposals, image_shapes, targets=None, return_loss=False):
        """
        Arguments:
            features (List[Tensor])
            proposals (List[Tensor[N, 4]])
            image_shapes (List[Tuple[H, W]])
            targets (List[Dict])
            return_loss (Bool): return the loss (even if we are in eval mode)

        Returns:
            result (List[Dict]): one dict per image with `boxes`, `labels` and
                `scores` (plus `masks` / `keypoints` / `keypoints_scores` when
                the corresponding heads are present). Empty list when the loss
                is computed.
            losses (Dict): `loss_classifier` and `loss_box_reg` (plus
                `loss_mask` / `loss_keypoint` when the corresponding heads are
                present) when the loss is computed, otherwise an empty dict.

        Raises:
            ValueError: if the loss is requested (training mode or
                `return_loss=True`) but `targets` is None.
        """
        compute_loss = self.training or return_loss
        # Fail early with a clear message instead of handing `None` to
        # `select_training_samples` further down.
        if compute_loss and targets is None:
            raise ValueError("targets should be passed when the loss is requested "
                             "(training mode or return_loss=True)")

        if targets is not None:
            for t in targets:
                assert t["boxes"].dtype.is_floating_point, 'target boxes must of float type'
                assert t["labels"].dtype == torch.int64, 'target labels must of int64 type'
                # NOTE: the dtype of t["keypoints"] is deliberately not validated here.

        if compute_loss:
            proposals, matched_idxs, labels, regression_targets = self.select_training_samples(proposals, targets)

        box_features = self.box_roi_pool(features, proposals, image_shapes)
        box_features = self.box_head(box_features)
        class_logits, box_regression = self.box_predictor(box_features)

        result, losses = [], {}
        if compute_loss:
            loss_classifier, loss_box_reg = fastrcnn_loss(
                class_logits, box_regression, labels, regression_targets)
            losses = dict(loss_classifier=loss_classifier, loss_box_reg=loss_box_reg)
        else:
            boxes, scores, labels = self.postprocess_detections(class_logits, box_regression, proposals, image_shapes)
            num_images = len(boxes)
            for i in range(num_images):
                result.append(
                    dict(
                        boxes=boxes[i],
                        labels=labels[i],
                        scores=scores[i],
                    )
                )

        if self.has_mask:
            if compute_loss:
                # during training, only focus on positive boxes
                mask_proposals, pos_matched_idxs = self._select_positive_samples(
                    proposals, labels, matched_idxs)
            else:
                mask_proposals = [p["boxes"] for p in result]

            mask_features = self.mask_roi_pool(features, mask_proposals, image_shapes)
            mask_features = self.mask_head(mask_features)
            mask_logits = self.mask_predictor(mask_features)

            loss_mask = {}
            if compute_loss:
                gt_masks = [t["masks"] for t in targets]
                gt_labels = [t["labels"] for t in targets]
                loss_mask = maskrcnn_loss(
                    mask_logits, mask_proposals,
                    gt_masks, gt_labels, pos_matched_idxs)
                loss_mask = dict(loss_mask=loss_mask)
            else:
                labels = [r["labels"] for r in result]
                masks_probs = maskrcnn_inference(mask_logits, labels)
                for mask_prob, r in zip(masks_probs, result):
                    r["masks"] = mask_prob

            losses.update(loss_mask)

        if self.has_keypoint:
            if compute_loss:
                # during training, only focus on positive boxes
                keypoint_proposals, pos_matched_idxs = self._select_positive_samples(
                    proposals, labels, matched_idxs)
            else:
                keypoint_proposals = [p["boxes"] for p in result]

            keypoint_features = self.keypoint_roi_pool(features, keypoint_proposals, image_shapes)
            keypoint_features = self.keypoint_head(keypoint_features)
            keypoint_logits = self.keypoint_predictor(keypoint_features)

            loss_keypoint = {}
            if compute_loss:
                gt_keypoints = [t["keypoints"] for t in targets]
                loss_keypoint = keypointrcnn_loss(
                    keypoint_logits, keypoint_proposals,
                    gt_keypoints, pos_matched_idxs)
                loss_keypoint = dict(loss_keypoint=loss_keypoint)
            else:
                keypoints_probs, kp_scores = keypointrcnn_inference(keypoint_logits, keypoint_proposals)
                for keypoint_prob, kps, r in zip(keypoints_probs, kp_scores, result):
                    r["keypoints"] = keypoint_prob
                    r["keypoints_scores"] = kps

            losses.update(loss_keypoint)

        return result, losses
class MyFasterRCNN(GeneralizedRCNN):
    """ The same definition as FasterRCNN(GeneralizedRCNN) but with different:
        - forward function: calculation of loss in eval() mode is now possible
        - RegionProposalNetwork -> MyRegionProposalNetwork (with modified forward function)
        - RoIHeads -> MyRoIHeads (with modified forward function)
    """

    def __init__(self, backbone, num_classes=None,
                 # transform parameters
                 min_size=800, max_size=1333,
                 image_mean=None, image_std=None,
                 # RPN parameters
                 rpn_anchor_generator=None, rpn_head=None,
                 rpn_pre_nms_top_n_train=2000, rpn_pre_nms_top_n_test=1000,
                 rpn_post_nms_top_n_train=2000, rpn_post_nms_top_n_test=1000,
                 rpn_nms_thresh=0.7,
                 rpn_fg_iou_thresh=0.7, rpn_bg_iou_thresh=0.3,
                 rpn_batch_size_per_image=256, rpn_positive_fraction=0.5,
                 # Box parameters
                 box_roi_pool=None, box_head=None, box_predictor=None,
                 box_score_thresh=0.05, box_nms_thresh=0.5, box_detections_per_img=100,
                 box_fg_iou_thresh=0.5, box_bg_iou_thresh=0.5,
                 box_batch_size_per_image=512, box_positive_fraction=0.25,
                 bbox_reg_weights=None):
        """
        Build the transform, the RPN (`MyRegionProposalNetwork`) and the RoI
        heads (`MyRoIHeads`) and hand them to `GeneralizedRCNN`.

        Arguments:
            backbone: feature extractor; must expose an `out_channels` attribute.
            num_classes: number of output classes for the default box
                predictor. Must be None when `box_predictor` is given and
                must not be None otherwise.
            min_size, max_size, image_mean, image_std: forwarded to
                `GeneralizedRCNNTransform`. `image_mean` / `image_std` default
                to [0.485, 0.456, 0.406] / [0.229, 0.224, 0.225].
            rpn_*: forwarded to `MyRegionProposalNetwork`. A default
                `AnchorGenerator` / `RPNHead` is created when
                `rpn_anchor_generator` / `rpn_head` is None.
            box_*, bbox_reg_weights: forwarded to `MyRoIHeads`. A default
                `MultiScaleRoIAlign` / `TwoMLPHead` / `FastRCNNPredictor` is
                created when `box_roi_pool` / `box_head` / `box_predictor` is None.

        Raises:
            ValueError: if `backbone` has no `out_channels` attribute, or if
                `num_classes` and `box_predictor` are both given or both None.
        """
        if not hasattr(backbone, "out_channels"):
            raise ValueError(
                "backbone should contain an attribute out_channels "
                "specifying the number of output channels (assumed to be the "
                "same for all the levels)")

        assert isinstance(rpn_anchor_generator, (AnchorGenerator, type(None)))
        assert isinstance(box_roi_pool, (MultiScaleRoIAlign, type(None)))

        if num_classes is not None:
            if box_predictor is not None:
                raise ValueError("num_classes should be None when box_predictor is specified")
        else:
            if box_predictor is None:
                raise ValueError("num_classes should not be None when box_predictor "
                                 "is not specified")

        out_channels = backbone.out_channels

        if rpn_anchor_generator is None:
            anchor_sizes = ((32,), (64,), (128,), (256,), (512,))
            aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
            rpn_anchor_generator = AnchorGenerator(
                anchor_sizes, aspect_ratios
            )
        if rpn_head is None:
            rpn_head = RPNHead(
                out_channels, rpn_anchor_generator.num_anchors_per_location()[0]
            )

        rpn_pre_nms_top_n = dict(training=rpn_pre_nms_top_n_train, testing=rpn_pre_nms_top_n_test)
        rpn_post_nms_top_n = dict(training=rpn_post_nms_top_n_train, testing=rpn_post_nms_top_n_test)

        rpn = MyRegionProposalNetwork(
            rpn_anchor_generator, rpn_head,
            rpn_fg_iou_thresh, rpn_bg_iou_thresh,
            rpn_batch_size_per_image, rpn_positive_fraction,
            rpn_pre_nms_top_n, rpn_post_nms_top_n, rpn_nms_thresh)

        if box_roi_pool is None:
            box_roi_pool = MultiScaleRoIAlign(
                featmap_names=[0, 1, 2, 3],
                output_size=7,
                sampling_ratio=2)

        if box_head is None:
            resolution = box_roi_pool.output_size[0]
            representation_size = 1024
            box_head = TwoMLPHead(
                out_channels * resolution ** 2,
                representation_size)

        if box_predictor is None:
            # set again here because `box_head` may have been supplied by the caller
            representation_size = 1024
            box_predictor = FastRCNNPredictor(
                representation_size,
                num_classes)

        roi_heads = MyRoIHeads(
            # Box
            box_roi_pool, box_head, box_predictor,
            box_fg_iou_thresh, box_bg_iou_thresh,
            box_batch_size_per_image, box_positive_fraction,
            bbox_reg_weights,
            box_score_thresh, box_nms_thresh, box_detections_per_img)

        if image_mean is None:
            image_mean = [0.485, 0.456, 0.406]
        if image_std is None:
            image_std = [0.229, 0.224, 0.225]
        transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std)

        super(MyFasterRCNN, self).__init__(backbone, rpn, roi_heads, transform)

    def forward(self, images, targets=None, return_loss=False):
        """
        Arguments:
            images (list[Tensor]): images to be processed
            targets (list[Dict[Tensor]]): ground-truth boxes present in the image (optional)
            return_loss (Bool): return the loss (even if we are in eval mode)

        Returns:
            result (list[BoxList] or dict[Tensor]): the output from the model.
                In training mode, or when `return_loss` is True, it returns a
                dict[Tensor] which contains the losses.
                Otherwise it returns the post-processed detections.

        Raises:
            ValueError: if the loss is requested (training mode or
                `return_loss=True`) but `targets` is None.
        """
        if targets is None:
            if self.training:
                raise ValueError("In training mode, targets should be passed")
            if return_loss:
                # The loss cannot be computed without ground truth; report it
                # here rather than failing deep inside the RPN.
                raise ValueError("targets should be passed when return_loss is True")

        original_image_sizes = [img.shape[-2:] for img in images]
        images, targets = self.transform(images, targets)
        features = self.backbone(images.tensors)
        if isinstance(features, torch.Tensor):
            # a single feature map is wrapped so that the RPN / RoI heads
            # always receive a mapping
            features = OrderedDict([(0, features)])
        proposals, proposal_losses = self.rpn(images, features, targets, return_loss)
        detections, detector_losses = self.roi_heads(features, proposals, images.image_sizes, targets, return_loss)
        detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes)

        losses = {}
        losses.update(detector_losses)
        losses.update(proposal_losses)

        if self.training or return_loss:
            return losses

        return detections
class MyMaskRCNN(MyFasterRCNN):
    """ Mask R-CNN built on top of MyFasterRCNN (instead of FasterRCNN), so that
        the loss can also be obtained in eval() mode via `return_loss`.
    """

    def __init__(self, backbone, num_classes=None,
                 # transform parameters
                 min_size=800, max_size=1333,
                 image_mean=None, image_std=None,
                 # RPN parameters
                 rpn_anchor_generator=None, rpn_head=None,
                 rpn_pre_nms_top_n_train=2000, rpn_pre_nms_top_n_test=1000,
                 rpn_post_nms_top_n_train=2000, rpn_post_nms_top_n_test=1000,
                 rpn_nms_thresh=0.7,
                 rpn_fg_iou_thresh=0.7, rpn_bg_iou_thresh=0.3,
                 rpn_batch_size_per_image=256, rpn_positive_fraction=0.5,
                 # Box parameters
                 box_roi_pool=None, box_head=None, box_predictor=None,
                 box_score_thresh=0.05, box_nms_thresh=0.5, box_detections_per_img=100,
                 box_fg_iou_thresh=0.5, box_bg_iou_thresh=0.5,
                 box_batch_size_per_image=512, box_positive_fraction=0.25,
                 bbox_reg_weights=None,
                 # Mask parameters
                 mask_roi_pool=None, mask_head=None, mask_predictor=None):
        """
        Create default mask components where none are supplied, initialise the
        MyFasterRCNN part, then attach the mask components to `self.roi_heads`.

        All non-mask arguments are passed through to `MyFasterRCNN.__init__`.
        `mask_roi_pool`, `mask_head` and `mask_predictor` default to a
        `MultiScaleRoIAlign`, a `MaskRCNNHeads` and a `MaskRCNNPredictor`
        respectively when left as None.

        Raises:
            ValueError: if both `num_classes` and `mask_predictor` are given.
        """
        assert mask_roi_pool is None or isinstance(mask_roi_pool, MultiScaleRoIAlign)

        if num_classes is not None and mask_predictor is not None:
            raise ValueError("num_classes should be None when mask_predictor is specified")

        backbone_channels = backbone.out_channels

        if mask_roi_pool is None:
            mask_roi_pool = MultiScaleRoIAlign(
                featmap_names=[0, 1, 2, 3],
                output_size=14,
                sampling_ratio=2)

        if mask_head is None:
            head_layers = (256, 256, 256, 256)
            head_dilation = 1
            mask_head = MaskRCNNHeads(backbone_channels, head_layers, head_dilation)

        if mask_predictor is None:
            # 256 matches the last layer width of the default mask head above
            predictor_in_channels = 256
            predictor_hidden_dim = 256
            mask_predictor = MaskRCNNPredictor(
                predictor_in_channels, predictor_hidden_dim, num_classes)

        super(MyMaskRCNN, self).__init__(
            backbone,
            num_classes,
            # transform parameters
            min_size=min_size,
            max_size=max_size,
            image_mean=image_mean,
            image_std=image_std,
            # RPN-specific parameters
            rpn_anchor_generator=rpn_anchor_generator,
            rpn_head=rpn_head,
            rpn_pre_nms_top_n_train=rpn_pre_nms_top_n_train,
            rpn_pre_nms_top_n_test=rpn_pre_nms_top_n_test,
            rpn_post_nms_top_n_train=rpn_post_nms_top_n_train,
            rpn_post_nms_top_n_test=rpn_post_nms_top_n_test,
            rpn_nms_thresh=rpn_nms_thresh,
            rpn_fg_iou_thresh=rpn_fg_iou_thresh,
            rpn_bg_iou_thresh=rpn_bg_iou_thresh,
            rpn_batch_size_per_image=rpn_batch_size_per_image,
            rpn_positive_fraction=rpn_positive_fraction,
            # Box parameters
            box_roi_pool=box_roi_pool,
            box_head=box_head,
            box_predictor=box_predictor,
            box_score_thresh=box_score_thresh,
            box_nms_thresh=box_nms_thresh,
            box_detections_per_img=box_detections_per_img,
            box_fg_iou_thresh=box_fg_iou_thresh,
            box_bg_iou_thresh=box_bg_iou_thresh,
            box_batch_size_per_image=box_batch_size_per_image,
            box_positive_fraction=box_positive_fraction,
            bbox_reg_weights=bbox_reg_weights,
        )

        # The mask components are attached after the parent constructor has
        # created `self.roi_heads`.
        roi_heads = self.roi_heads
        roi_heads.mask_roi_pool = mask_roi_pool
        roi_heads.mask_head = mask_head
        roi_heads.mask_predictor = mask_predictor
def my_maskrcnn_resnet50_fpn(pretrained=False, progress=True,
                             num_classes=91, pretrained_backbone=True, **kwargs):
    """ Build a `MyMaskRCNN` on a ResNet-50 FPN backbone.

    Counterpart of maskrcnn_resnet50_fpn() that instantiates `MyMaskRCNN`
    instead of `MaskRCNN`.

    Arguments:
        pretrained (bool): if truthy, load the state dict found at
            model_urls['maskrcnn_resnet50_fpn_coco'] into the model
        progress (bool): forwarded to `load_state_dict_from_url`
        num_classes (int): forwarded to `MyMaskRCNN`
        pretrained_backbone (bool): forwarded to `resnet_fpn_backbone`;
            overridden with False when `pretrained` is truthy
        **kwargs: extra keyword arguments forwarded to `MyMaskRCNN`

    Returns:
        the constructed `MyMaskRCNN` instance
    """
    # no need to download the backbone if the full pretrained model is loaded
    backbone_weights = False if pretrained else pretrained_backbone
    backbone = resnet_fpn_backbone('resnet50', backbone_weights)

    # This is the place where we use our modified model.
    model = MyMaskRCNN(backbone, num_classes, **kwargs)
    if not pretrained:
        return model

    checkpoint_url = model_urls['maskrcnn_resnet50_fpn_coco']
    weights = load_state_dict_from_url(checkpoint_url, progress=progress)
    model.load_state_dict(weights)
    return model