-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathsolov2.py
executable file
·949 lines (801 loc) · 42 KB
/
solov2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
# -*- coding: utf-8 -*-
import logging
import math
from typing import List
import torch
import torch.nn.functional as F
from torch import nn
from detectron2.layers import ShapeSpec, batched_nms, cat, paste_masks_in_image
from detectron2.modeling.anchor_generator import DefaultAnchorGenerator
from detectron2.modeling.backbone import build_backbone
from detectron2.modeling.box_regression import Box2BoxTransform
from detectron2.modeling.meta_arch.build import META_ARCH_REGISTRY
from detectron2.structures import Boxes, ImageList, Instances
from detectron2.utils.logger import log_first_n
from fvcore.nn import sigmoid_focal_loss_jit
from .utils import imrescale, center_of_mass, point_nms, mask_nms, matrix_nms
from .loss import dice_loss, FocalLoss
__all__ = ["SOLOv2"]

# Base used to pack a normalized 2-D point (y, x) into one integer:
#     int(y * PROJECTION) * PROJECTION + int(x * PROJECTION)
# and to unpack it again with `// PROJECTION` (y) and `% PROJECTION` (x).
PROJECTION = 10_000
@META_ARCH_REGISTRY.register()
class SOLOv2(nn.Module):
    """
    SOLOv2 model. Creates FPN backbone, instance branch for kernels and categories prediction,
    mask branch for unified mask features.
    Calculates and applies proper losses to class and masks.

    When ``MODEL.SOLOV2.PROMPT`` is not ``"none"``, inference implements PointWSSIS
    (https://arxiv.org/abs/2303.15062): the ground-truth annotations attached to
    the inputs are used as prompts to filter or select the predicted instances.
    """

    def __init__(self, cfg):
        super().__init__()
        # get the device of the model
        self.device = torch.device(cfg.MODEL.DEVICE)

        self.scale_ranges = cfg.MODEL.SOLOV2.FPN_SCALE_RANGES
        self.strides = cfg.MODEL.SOLOV2.FPN_INSTANCE_STRIDES
        self.sigma = cfg.MODEL.SOLOV2.SIGMA

        # Instance parameters.
        self.num_classes = cfg.MODEL.SOLOV2.NUM_CLASSES
        self.num_kernels = cfg.MODEL.SOLOV2.NUM_KERNELS
        self.num_grids = cfg.MODEL.SOLOV2.NUM_GRIDS
        self.instance_in_features = cfg.MODEL.SOLOV2.INSTANCE_IN_FEATURES
        self.instance_strides = cfg.MODEL.SOLOV2.FPN_INSTANCE_STRIDES
        self.instance_in_channels = cfg.MODEL.SOLOV2.INSTANCE_IN_CHANNELS  # = fpn.
        self.instance_channels = cfg.MODEL.SOLOV2.INSTANCE_CHANNELS

        # Mask parameters.
        self.mask_on = cfg.MODEL.MASK_ON
        self.mask_in_features = cfg.MODEL.SOLOV2.MASK_IN_FEATURES
        self.mask_in_channels = cfg.MODEL.SOLOV2.MASK_IN_CHANNELS
        self.mask_channels = cfg.MODEL.SOLOV2.MASK_CHANNELS
        self.num_masks = cfg.MODEL.SOLOV2.NUM_MASKS

        # Inference parameters.
        self.max_before_nms = cfg.MODEL.SOLOV2.NMS_PRE
        self.score_threshold = cfg.MODEL.SOLOV2.SCORE_THR
        self.update_threshold = cfg.MODEL.SOLOV2.UPDATE_THR
        self.mask_threshold = cfg.MODEL.SOLOV2.MASK_THR
        self.max_per_img = cfg.MODEL.SOLOV2.MAX_PER_IMG
        self.nms_kernel = cfg.MODEL.SOLOV2.NMS_KERNEL
        self.nms_sigma = cfg.MODEL.SOLOV2.NMS_SIGMA
        self.nms_type = cfg.MODEL.SOLOV2.NMS_TYPE
        # One of "none", "cls", "point", "point_with_size" (validated in forward()).
        self.prompt = cfg.MODEL.SOLOV2.PROMPT
        # When set, every kept prediction gets a confidence of 1.0.
        self.eval_pseudo_label = cfg.MODEL.SOLOV2.EVAL_PSEUDO_LABEL

        # NOTE: keep the construction order (backbone, ins head, mask head): the
        # heads draw their initial weights from the global RNG.
        # build the backbone.
        self.backbone = build_backbone(cfg)
        backbone_shape = self.backbone.output_shape()

        # build the ins head.
        instance_shapes = [backbone_shape[f] for f in self.instance_in_features]
        self.ins_head = SOLOv2InsHead(cfg, instance_shapes)

        # build the mask head.
        mask_shapes = [backbone_shape[f] for f in self.mask_in_features]
        self.mask_head = SOLOv2MaskHead(cfg, mask_shapes)

        # loss
        self.ins_loss_weight = cfg.MODEL.SOLOV2.LOSS.DICE_WEIGHT
        self.focal_loss_alpha = cfg.MODEL.SOLOV2.LOSS.FOCAL_ALPHA
        self.focal_loss_gamma = cfg.MODEL.SOLOV2.LOSS.FOCAL_GAMMA
        self.focal_loss_weight = cfg.MODEL.SOLOV2.LOSS.FOCAL_WEIGHT

        # image transform
        pixel_mean = torch.Tensor(cfg.MODEL.PIXEL_MEAN).to(self.device).view(3, 1, 1)
        pixel_std = torch.Tensor(cfg.MODEL.PIXEL_STD).to(self.device).view(3, 1, 1)
        self.normalizer = lambda x: (x - pixel_mean) / pixel_std
        self.to(self.device)

    def forward(self, batched_inputs):
        """
        Args:
            batched_inputs: a list, batched outputs of :class:`DetectionTransform` .
                Each item in the list contains the inputs for one image.
                For now, each item in the list is a dict that contains:
                    image: Tensor, image in (C, H, W) format.
                    instances: Instances
                    Other information that's included in the original dicts, such as:
                        "height", "width" (int): the output resolution of the model, used in inference.
                            See :meth:`postprocess` for details.
        Returns:
            In training: losses (dict[str: Tensor]) mapping a loss name to its value.
            In inference: list[dict], one ``{"instances": Instances}`` per image.

        Raises:
            ValueError: in inference, if ``MODEL.SOLOV2.PROMPT`` is not one of
                "none", "cls", "point", "point_with_size", or if a prompt mode
                is used without ground-truth instances in the inputs.
        """
        images = self.preprocess_image(batched_inputs)
        if "instances" in batched_inputs[0]:
            gt_instances = [x["instances"].to(self.device) for x in batched_inputs]
        elif "targets" in batched_inputs[0]:
            log_first_n(
                logging.WARN, "'targets' in the model inputs is now renamed to 'instances'!", n=10
            )
            gt_instances = [x["targets"].to(self.device) for x in batched_inputs]
        else:
            gt_instances = None

        features = self.backbone(images.tensor)

        # ins branch
        ins_features = [features[f] for f in self.instance_in_features]
        ins_features = self.split_feats(ins_features)
        cate_pred, kernel_pred = self.ins_head(ins_features)

        # mask branch
        mask_features = [features[f] for f in self.mask_in_features]
        mask_pred = self.mask_head(mask_features)

        if self.training:
            # Build the targets at the mask-feature resolution and compute losses.
            mask_feat_size = mask_pred.size()[-2:]
            targets = self.get_ground_truth(gt_instances, mask_feat_size)
            losses = self.loss(cate_pred, kernel_pred, mask_pred, targets)
            return losses

        if self.prompt == "none":
            # point nms.
            cate_pred = [point_nms(cate_p.sigmoid(), kernel=2).permute(0, 2, 3, 1) for cate_p in cate_pred]
            # `points` is only read when a prompt is used; cate_pred is passed
            # because it already has the per-cell layout inference() expects.
            return self.inference(cate_pred, kernel_pred, mask_pred, images.image_sizes, batched_inputs, cate_pred)

        # Implementation for PointWSSIS (https://arxiv.org/abs/2303.15062),
        # Weakly Semi-Supervised Instance Segmentation with Point Labels.
        if self.prompt not in ("cls", "point", "point_with_size"):
            raise ValueError(
                "MODEL.SOLOV2.PROMPT must be one of none, cls, point, point_with_size "
                "(got {!r})".format(self.prompt))
        if gt_instances is None:
            raise ValueError(
                "MODEL.SOLOV2.PROMPT={!r} requires ground-truth 'instances' in the inputs".format(self.prompt))

        # do inference with prompt.
        targets = self.get_ground_truth(gt_instances, mask_pred.size()[-2:])
        # NOTE(review): these per-(image, level) lists are flattened in image-major
        # order but consumed per level (inference() indexes [level][img_idx]), so
        # the prompt path effectively assumes one image per batch -- TODO confirm.
        points = [t.unsqueeze(0) for target in targets[5] for t in target]
        cate_gt = self._onehot_cate_targets(targets[1])

        if self.prompt == "cls":
            # prompt -> image-level labels : removing misclassified proposals
            cate_pred = [point_nms(cate_p.sigmoid(), kernel=2).permute(0, 2, 3, 1)
                         for cate_p in cate_pred]
            # score filtering with gt cls labels
            cate_pred = [c_pred * c_gt for c_pred, c_gt in zip(cate_pred, cate_gt)]
            return self.inference(cate_pred, kernel_pred, mask_pred, images.image_sizes, batched_inputs, points)

        if self.prompt == "point":
            # prompt -> point labels : removing false-positive and false-negative proposals
            point_encodings_flatten = [t.unsqueeze(0) for target in targets[4] for t in target]
            new_cate_gt = self._select_point_prompt_cells(
                cate_pred, kernel_pred, mask_pred, cate_gt, point_encodings_flatten)
            return self.inference(new_cate_gt, kernel_pred, mask_pred, images.image_sizes, batched_inputs, points)

        # prompt -> point labels including instance size information:
        # the GT category grids are used directly as the scores.
        return self.inference(cate_gt, kernel_pred, mask_pred, images.image_sizes, batched_inputs, points)

    def _onehot_cate_targets(self, cate_label_list):
        """
        Convert per-image, per-level category label grids to one-hot score maps.

        Args:
            cate_label_list: second output of :meth:`get_ground_truth`; a list
                (images) of lists (levels) of int64 [S, S] grids whose
                background value is ``self.num_classes``.

        Returns:
            list[Tensor]: float [1, S, S, num_classes] maps flattened in
            (image, level) order; the background channel is dropped.
        """
        onehots = []
        for per_image in cate_label_list:
            for labels in per_image:
                onehot = F.one_hot(labels, num_classes=self.num_classes + 1)
                onehots.append(onehot[:, :, :self.num_classes].unsqueeze(0).float())
        return onehots

    def _select_point_prompt_cells(self, cate_pred, kernel_pred, mask_pred, cate_gt, point_encodings_flatten):
        """
        Adaptive pyramid selection for the "point" prompt.

        Every distinct GT point is evaluated on every FPN level; only the
        (level, cell) whose decoded mask scores highest (maskness x category
        score) is kept as a positive.

        Args:
            cate_pred: per-level raw category logits, [1, C, S, S].
            kernel_pred: per-level kernel predictions, [1, num_kernels, S, S].
            mask_pred: unified mask features.
            cate_gt: one-hot GT maps from :meth:`_onehot_cate_targets`, [1, S, S, C].
            point_encodings_flatten: packed GT points (see PROJECTION), [1, S, S, C].

        Returns:
            list[Tensor]: maps shaped like ``cate_gt`` with 1.0 at the selected cells.
        """
        point_encodings = []
        # zip() truncates to len(cate_pred), i.e. the levels of the first image.
        for _, cate_g, point in zip(cate_pred, cate_gt, point_encodings_flatten):
            # cate_g : [1, S, S, C]
            # point  : [1, S, S, C]
            for p, c_nonzero in zip(point[cate_g.nonzero(as_tuple=True)], cate_g.nonzero()):
                cy = float(p // PROJECTION) / float(PROJECTION)
                cx = float(p % PROJECTION) / float(PROJECTION)
                point_encodings.append((c_nonzero[0].item(), cy, cx, c_nonzero[3].item()))
        # A point covers several grid cells; evaluate each distinct point once.
        point_encodings = list(set(point_encodings))

        new_cate_gt = [torch.zeros_like(gt) for gt in cate_gt]
        for p_encoding in point_encodings:
            max_score, max_fpn_level, max_idx = -1, 0, (0, 0, 0, 0)
            for fpn_level, (cate_p, _, kernel_p) in enumerate(zip(cate_pred, cate_gt, kernel_pred)):
                cate_p = cate_p.sigmoid().permute(0, 2, 3, 1)
                s = cate_p.size(2)  # grid size
                # (batch, row, col, class) of the cell containing the point on this level.
                idx = tuple(map(int, [
                    p_encoding[0],
                    max(0, min(s - 1, int(p_encoding[1] // (1. / s)))),
                    max(0, min(s - 1, int(p_encoding[2] // (1. / s)))),
                    p_encoding[3],
                ]))
                # obtaining kernels decoded from the point
                kernel = kernel_p.permute(0, 2, 3, 1)[idx[:-1]]
                kernel = kernel[None, :, None, None]
                # from point to mask
                seg_pred = F.conv2d(mask_pred, kernel, stride=1).squeeze(0).sigmoid()
                seg_mask = seg_pred > self.mask_threshold
                sum_mask = seg_mask.sum((1, 2)).float()
                seg_score = (seg_pred * seg_mask.float()).sum((1, 2)) / (sum_mask + 1e-8)
                score = seg_score * cate_p[idx]
                # adaptive pyramid selection: get max-scoring fpn-level
                if score > max_score:
                    max_score = score
                    max_fpn_level = fpn_level
                    max_idx = idx
            if max_score > 0:
                new_cate_gt[max_fpn_level][max_idx] = 1.
        return new_cate_gt

    def preprocess_image(self, batched_inputs):
        """
        Normalize, pad and batch the input images.
        """
        images = [x["image"].to(self.device) for x in batched_inputs]
        images = [self.normalizer(x) for x in images]
        images = ImageList.from_tensors(images, self.backbone.size_divisibility)
        return images

    @torch.no_grad()
    def get_ground_truth(self, gt_instances, mask_feat_size=None):
        """
        Build targets for every image by calling :meth:`get_ground_truth_single`.

        Returns:
            A 6-tuple of lists (one entry per image, each a per-level list):
            ins labels, cate labels, ins indicator labels, grid orders,
            cate point encodings, cate point originals.
        """
        ins_label_list, cate_label_list, ins_ind_label_list, grid_order_list = [], [], [], []
        cate_point_encoding_list, cate_point_original_list = [], []
        for img_idx in range(len(gt_instances)):
            (cur_ins_label_list, cur_cate_label_list, cur_ins_ind_label_list,
             cur_grid_order_list, cur_cate_point_encoding_list,
             cur_cate_point_original_list) = self.get_ground_truth_single(
                img_idx, gt_instances, mask_feat_size=mask_feat_size)
            ins_label_list.append(cur_ins_label_list)
            cate_label_list.append(cur_cate_label_list)
            ins_ind_label_list.append(cur_ins_ind_label_list)
            grid_order_list.append(cur_grid_order_list)
            cate_point_encoding_list.append(cur_cate_point_encoding_list)
            cate_point_original_list.append(cur_cate_point_original_list)
        return (ins_label_list, cate_label_list, ins_ind_label_list, grid_order_list,
                cate_point_encoding_list, cate_point_original_list)

    def get_ground_truth_single(self, img_idx, gt_instances, mask_feat_size):
        """
        Build the training / prompt targets for one image.

        Each instance is assigned to the levels whose scale range contains
        sqrt(box area). On such a level, the grid cells covered by the
        sigma-shrunk box around the mask's center of mass, clipped to the 3x3
        neighbourhood of the center cell, become positive. In eval with
        ``prompt == 'point'`` only the center cell is used.

        Returns:
            Six per-level lists:
            ins_label (uint8 [num_pos, *mask_feat_size], GT masks downscaled by 4),
            cate_label (int64 [S, S], background = num_classes),
            ins_ind_label (bool [S*S]), grid_order (positive cell indices),
            cate_point_encoding / cate_point_original (int64 [S, S, num_classes],
            packed center points normalized by ``mask_feat_size * 4`` and by the
            GT mask size respectively; see PROJECTION).
        """
        gt_bboxes_raw = gt_instances[img_idx].gt_boxes.tensor
        gt_labels_raw = gt_instances[img_idx].gt_classes
        gt_masks_raw = gt_instances[img_idx].gt_masks.tensor

        ins_label_list = []
        cate_label_list = []
        ins_ind_label_list = []
        grid_order_list = []
        cate_point_encoding_list = []
        cate_point_original_list = []

        # handling unlabeled data: all-background targets on every level.
        if gt_labels_raw.shape[0] == 0:
            for num_grid in self.num_grids:
                ins_label = torch.zeros([0, mask_feat_size[0], mask_feat_size[1]], dtype=torch.uint8, device=self.device)
                cate_label = torch.zeros([num_grid, num_grid], dtype=torch.int64, device=self.device)
                cate_label = torch.fill_(cate_label, self.num_classes)
                ins_ind_label = torch.zeros([num_grid ** 2], dtype=torch.bool, device=self.device)
                cate_point_encoding = torch.zeros([num_grid, num_grid, self.num_classes], dtype=torch.int64, device=self.device)
                cate_point_original = torch.zeros([num_grid, num_grid, self.num_classes], dtype=torch.int64, device=self.device)
                ins_label_list.append(ins_label)
                cate_label_list.append(cate_label)
                ins_ind_label_list.append(ins_ind_label)
                grid_order_list.append([])
                cate_point_encoding_list.append(cate_point_encoding)
                cate_point_original_list.append(cate_point_original)
            return (ins_label_list, cate_label_list, ins_ind_label_list, grid_order_list,
                    cate_point_encoding_list, cate_point_original_list)

        device = gt_labels_raw[0].device

        # ins: per-instance scale = sqrt(box area), used to pick FPN levels.
        gt_areas = torch.sqrt((gt_bboxes_raw[:, 2] - gt_bboxes_raw[:, 0]) * (
                gt_bboxes_raw[:, 3] - gt_bboxes_raw[:, 1]))

        for (lower_bound, upper_bound), stride, num_grid \
                in zip(self.scale_ranges, self.strides, self.num_grids):

            hit_indices = ((gt_areas >= lower_bound) & (gt_areas < upper_bound)).nonzero().flatten()
            num_ins = len(hit_indices)

            ins_label = []
            grid_order = []
            cate_label = torch.zeros([num_grid, num_grid], dtype=torch.int64, device=device)
            cate_label = torch.fill_(cate_label, self.num_classes)
            ins_ind_label = torch.zeros([num_grid ** 2], dtype=torch.bool, device=device)
            cate_point_encoding = torch.zeros([num_grid, num_grid, self.num_classes], dtype=torch.int64, device=device)
            cate_point_original = torch.zeros([num_grid, num_grid, self.num_classes], dtype=torch.int64, device=device)

            if num_ins == 0:
                ins_label = torch.zeros([0, mask_feat_size[0], mask_feat_size[1]], dtype=torch.uint8, device=device)
                ins_label_list.append(ins_label)
                cate_label_list.append(cate_label)
                ins_ind_label_list.append(ins_ind_label)
                grid_order_list.append([])
                cate_point_encoding_list.append(cate_point_encoding)
                cate_point_original_list.append(cate_point_original)
                continue

            gt_bboxes = gt_bboxes_raw[hit_indices]
            gt_labels = gt_labels_raw[hit_indices]
            gt_masks = gt_masks_raw[hit_indices, ...]

            half_ws = 0.5 * (gt_bboxes[:, 2] - gt_bboxes[:, 0]) * self.sigma
            half_hs = 0.5 * (gt_bboxes[:, 3] - gt_bboxes[:, 1]) * self.sigma

            # mass center
            mh, mw = gt_masks.shape[1:]
            center_ws, center_hs = center_of_mass(gt_masks)
            valid_mask_flags = gt_masks.sum(dim=-1).sum(dim=-1) > 0

            # Downscale the GT masks by 4 (H x W x N layout for imrescale).
            output_stride = 4
            gt_masks = gt_masks.permute(1, 2, 0).to(dtype=torch.uint8).cpu().numpy()
            gt_masks = imrescale(gt_masks, scale=1. / output_stride)
            if len(gt_masks.shape) == 2:
                # a single mask came back without its channel axis
                gt_masks = gt_masks[..., None]
            gt_masks = torch.from_numpy(gt_masks).to(dtype=torch.uint8, device=device).permute(2, 0, 1)
            for seg_mask, gt_label, half_h, half_w, center_h, center_w, valid_mask_flag in zip(
                    gt_masks, gt_labels, half_hs, half_ws, center_hs, center_ws, valid_mask_flags):
                if not valid_mask_flag:
                    continue
                upsampled_size = (mask_feat_size[0] * 4, mask_feat_size[1] * 4)
                # grid cell containing the center of mass
                coord_w = int((center_w / upsampled_size[1]) // (1. / num_grid))
                coord_h = int((center_h / upsampled_size[0]) // (1. / num_grid))

                # left, top, right, down -> box for positive region
                if not self.training and self.prompt == 'point':
                    # point positive region (one-pixel positive sample assignment)
                    top_box = max(0, int(((center_h) / upsampled_size[0]) // (1. / num_grid)))
                    down_box = min(num_grid - 1, int(((center_h) / upsampled_size[0]) // (1. / num_grid)))
                    left_box = max(0, int(((center_w) / upsampled_size[1]) // (1. / num_grid)))
                    right_box = min(num_grid - 1, int(((center_w) / upsampled_size[1]) // (1. / num_grid)))
                else:  # box positive region (widen positive sample assignment)
                    top_box = max(0, int(((center_h - half_h) / upsampled_size[0]) // (1. / num_grid)))
                    down_box = min(num_grid - 1, int(((center_h + half_h) / upsampled_size[0]) // (1. / num_grid)))
                    left_box = max(0, int(((center_w - half_w) / upsampled_size[1]) // (1. / num_grid)))
                    right_box = min(num_grid - 1, int(((center_w + half_w) / upsampled_size[1]) // (1. / num_grid)))

                # clip the positive region to at most one cell around the center cell
                top = max(top_box, coord_h - 1)
                down = min(down_box, coord_h + 1)
                left = max(coord_w - 1, left_box)
                right = min(right_box, coord_w + 1)

                cate_label[top:(down + 1), left:(right + 1)] = gt_label
                # pack the center point (see PROJECTION) normalized by the padded size...
                enc_ch = center_h / upsampled_size[0] * PROJECTION
                enc_cw = center_w / upsampled_size[1] * PROJECTION
                cate_point_encoding[top:(down + 1), left:(right + 1)] = int(enc_ch) * PROJECTION + int(enc_cw)
                # ...and normalized by the GT mask size.
                ori_ch = center_h / mh * PROJECTION
                ori_cw = center_w / mw * PROJECTION
                cate_point_original[top:(down + 1), left:(right + 1)] = int(ori_ch) * PROJECTION + int(ori_cw)
                for i in range(top, down + 1):
                    for j in range(left, right + 1):
                        label = int(i * num_grid + j)

                        cur_ins_label = torch.zeros([mask_feat_size[0], mask_feat_size[1]], dtype=torch.uint8,
                                                    device=device)
                        cur_ins_label[:seg_mask.shape[0], :seg_mask.shape[1]] = seg_mask
                        ins_label.append(cur_ins_label)
                        ins_ind_label[label] = True
                        grid_order.append(label)
            if len(ins_label) == 0:
                ins_label = torch.zeros([0, mask_feat_size[0], mask_feat_size[1]], dtype=torch.uint8, device=device)
            else:
                ins_label = torch.stack(ins_label, 0)
            ins_label_list.append(ins_label)
            cate_label_list.append(cate_label)
            ins_ind_label_list.append(ins_ind_label)
            grid_order_list.append(grid_order)
            cate_point_encoding_list.append(cate_point_encoding)
            cate_point_original_list.append(cate_point_original)
        return (ins_label_list, cate_label_list, ins_ind_label_list, grid_order_list,
                cate_point_encoding_list, cate_point_original_list)

    def loss(self, cate_preds, kernel_preds, ins_pred, targets):
        """
        Compute the SOLOv2 losses.

        Args:
            cate_preds: per-level category logits, [B, C, S, S].
            kernel_preds: per-level kernel predictions, [B, num_kernels, S, S].
            ins_pred: unified mask features, [B, num_kernels, H, W].
            targets: output of :meth:`get_ground_truth` (only the first four
                entries are used).

        Returns:
            dict with ``loss_ins`` (weighted dice loss over positive cells) and
            ``loss_cate`` (weighted sigmoid focal loss normalized by num_pos + 1).
            If the batch has no positive cell at all, ``loss_ins`` is a zero that
            stays attached to ``ins_pred``'s graph.
        """
        ins_label_list, cate_label_list, ins_ind_label_list, grid_order_list = targets[:4]
        # ins: concatenate each level's GT masks across images.
        ins_labels = [torch.cat(ins_labels_level, 0)
                      for ins_labels_level in zip(*ins_label_list)]

        # Gather the kernels of the positive cells, per level and per image.
        kernel_preds = [[kernel_preds_level_img.view(kernel_preds_level_img.shape[0], -1)[:, grid_orders_level_img]
                         for kernel_preds_level_img, grid_orders_level_img in
                         zip(kernel_preds_level, grid_orders_level)]
                        for kernel_preds_level, grid_orders_level in zip(kernel_preds, zip(*grid_order_list))]
        # generate masks
        ins_pred_list = []
        for b_kernel_pred in kernel_preds:
            b_mask_pred = []
            for idx, kernel_pred in enumerate(b_kernel_pred):
                if kernel_pred.size()[-1] == 0:
                    continue
                cur_ins_pred = ins_pred[idx, ...]
                H, W = cur_ins_pred.shape[-2:]
                N, I = kernel_pred.shape
                cur_ins_pred = cur_ins_pred.unsqueeze(0)
                kernel_pred = kernel_pred.permute(1, 0).view(I, -1, 1, 1)
                cur_ins_pred = F.conv2d(cur_ins_pred, kernel_pred, stride=1).view(-1, H, W)
                b_mask_pred.append(cur_ins_pred)
            if len(b_mask_pred) == 0:
                b_mask_pred = None
            else:
                b_mask_pred = torch.cat(b_mask_pred, 0)
            ins_pred_list.append(b_mask_pred)

        ins_ind_labels = [
            torch.cat([ins_ind_labels_level_img.flatten()
                       for ins_ind_labels_level_img in ins_ind_labels_level])
            for ins_ind_labels_level in zip(*ins_ind_label_list)
        ]
        flatten_ins_ind_labels = torch.cat(ins_ind_labels)

        num_ins = flatten_ins_ind_labels.sum()

        # dice loss
        loss_ins = []
        for mask_logits, target in zip(ins_pred_list, ins_labels):
            if mask_logits is None:
                continue
            loss_ins.append(dice_loss(torch.sigmoid(mask_logits), target))
        if loss_ins:
            loss_ins_mean = torch.cat(loss_ins).mean()
        else:
            # No positive cell in the whole batch (e.g. only unlabeled images):
            # torch.cat([]) would raise, so return a zero that keeps the graph.
            loss_ins_mean = ins_pred.sum() * 0.0
        loss_ins = loss_ins_mean * self.ins_loss_weight

        # cate
        cate_labels = [
            torch.cat([cate_labels_level_img.flatten()
                       for cate_labels_level_img in cate_labels_level])
            for cate_labels_level in zip(*cate_label_list)
        ]
        flatten_cate_labels = torch.cat(cate_labels)

        cate_preds = [
            cate_pred.permute(0, 2, 3, 1).reshape(-1, self.num_classes)
            for cate_pred in cate_preds
        ]
        flatten_cate_preds = torch.cat(cate_preds)

        # prepare one_hot (background label == num_classes stays all-zero)
        pos_inds = torch.nonzero(flatten_cate_labels != self.num_classes).squeeze(1)

        flatten_cate_labels_oh = torch.zeros_like(flatten_cate_preds)
        flatten_cate_labels_oh[pos_inds, flatten_cate_labels[pos_inds]] = 1

        loss_cate = self.focal_loss_weight * sigmoid_focal_loss_jit(flatten_cate_preds, flatten_cate_labels_oh,
                                                                    gamma=self.focal_loss_gamma,
                                                                    alpha=self.focal_loss_alpha,
                                                                    reduction="sum") / (num_ins + 1)
        return {'loss_ins': loss_ins,
                'loss_cate': loss_cate}

    @staticmethod
    def split_feats(feats):
        """
        Adapt FPN outputs for the instance branch.

        The first level is downsampled by 2; with five levels the last one is
        resized to the fourth level's spatial size; all others pass through.
        Returns a tuple, or None when ``feats`` has 0 or more than 5 entries.
        """
        num_feats = len(feats)
        if num_feats < 1 or num_feats > 5:
            return None
        out = [F.interpolate(feats[0], scale_factor=0.5, mode='bilinear')]
        out.extend(feats[1:min(num_feats, 4)])
        if num_feats == 5:
            out.append(F.interpolate(feats[4], size=feats[3].shape[-2:], mode='bilinear'))
        return tuple(out)

    def inference(self, pred_cates, pred_kernels, pred_masks, cur_sizes, images, points):
        """
        Run per-image post-processing on the batched head outputs.

        Args:
            pred_cates: per-level category scores laid out so that
                ``pred_cates[level][img_idx].view(-1, num_classes)`` gives one
                row per grid cell (e.g. [B, S, S, C]).
            pred_kernels: per-level kernel predictions, [B, num_kernels, S, S].
            pred_masks: batched mask features; ``pred_masks[img_idx]`` is fed to
                the dynamic convolution.
            cur_sizes: per-image (h, w) of the resized, unpadded input.
            images: the batched inputs; only "height" and "width" are read.
            points: per-level tensors with the same per-cell layout as
                ``pred_cates`` (packed points; only read when a prompt is used).

        Returns:
            list[dict]: one ``{"instances": Instances}`` per image.
        """
        assert len(pred_cates) == len(pred_kernels)

        results = []
        num_ins_levels = len(pred_cates)
        for img_idx, ori_img in enumerate(images):
            # image size.
            ori_size = (ori_img["height"], ori_img["width"])

            # prediction, flattened to one row per grid cell over all levels.
            pred_cate = torch.cat([pred_cates[i][img_idx].view(-1, self.num_classes).detach()
                                   for i in range(num_ins_levels)], dim=0)
            pred_kernel = torch.cat([pred_kernels[i][img_idx].permute(1, 2, 0).view(-1, self.num_kernels).detach()
                                     for i in range(num_ins_levels)], dim=0)
            pred_mask = pred_masks[img_idx, ...].unsqueeze(0)
            point = torch.cat([points[i][img_idx].view(-1, self.num_classes).detach()
                               for i in range(num_ins_levels)], dim=0)

            # inference for single image.
            result = self.inference_single_image(pred_cate, pred_kernel, pred_mask,
                                                 cur_sizes[img_idx], ori_size, point)
            results.append({"instances": result})
        return results

    @staticmethod
    def _empty_instances(ori_size):
        """Return an Instances of size ``ori_size`` with empty prediction fields."""
        results = Instances(ori_size)
        results.scores = torch.tensor([])
        results.pred_classes = torch.tensor([])
        results.pred_masks = torch.tensor([])
        results.pred_boxes = Boxes(torch.tensor([]))
        return results

    def inference_single_image(
            self, cate_preds, kernel_preds, seg_preds, cur_size, ori_size, points
    ):
        """
        Decode the instances of one image.

        Args:
            cate_preds: [num_cells, num_classes] scores over all levels' grid cells.
            kernel_preds: [num_cells, num_kernels] dynamic kernels, same cell order.
            seg_preds: [1, num_kernels, f_h, f_w] mask features of this image.
            cur_size: (h, w) of the resized, unpadded image.
            ori_size: (height, width) of the output masks.
            points: [num_cells, num_classes] packed points aligned with
                ``cate_preds`` (see PROJECTION); only read when a prompt is used.

        Returns:
            Instances with ``scores``, ``pred_classes``, ``pred_masks`` and
            ``pred_boxes``. Without a prompt, each box is the tight bound of its
            mask (all zeros for an empty mask); with a prompt, each "box" stores
            the decoded prompt point as (y, x, 0, 0) in PROJECTION units.
        """
        # overall info.
        h, w = cur_size
        f_h, f_w = seg_preds.size()[-2:]
        ratio = math.ceil(h / f_h)
        upsampled_size_out = (int(f_h * ratio), int(f_w * ratio))

        # process.
        inds = (cate_preds > self.score_threshold)
        cate_scores = cate_preds[inds]
        points = points[inds]
        if len(cate_scores) == 0:
            return self._empty_instances(ori_size)

        # cate_labels & kernel_preds
        inds = inds.nonzero()
        cate_labels = inds[:, 1]
        kernel_preds = kernel_preds[inds[:, 0]]

        # trans vector: FPN stride of the level each candidate cell belongs to.
        size_trans = cate_labels.new_tensor(self.num_grids).pow(2).cumsum(0)
        strides = kernel_preds.new_ones(size_trans[-1])

        n_stage = len(self.num_grids)
        strides[:size_trans[0]] *= self.instance_strides[0]
        for ind_ in range(1, n_stage):
            strides[size_trans[ind_ - 1]:size_trans[ind_]] *= self.instance_strides[ind_]
        strides = strides[inds[:, 0]]

        # mask encoding.
        N, I = kernel_preds.shape
        kernel_preds = kernel_preds.view(N, I, 1, 1)
        seg_preds = F.conv2d(seg_preds, kernel_preds, stride=1).squeeze(0).sigmoid()

        # mask.
        seg_masks = seg_preds > self.mask_threshold
        sum_masks = seg_masks.sum((1, 2)).float()

        # filter: drop masks whose area does not exceed their level's stride.
        keep = sum_masks > strides
        if keep.sum() == 0:
            return self._empty_instances(ori_size)

        seg_masks = seg_masks[keep, ...]
        seg_preds = seg_preds[keep, ...]
        sum_masks = sum_masks[keep]
        cate_scores = cate_scores[keep]
        cate_labels = cate_labels[keep]
        points = points[keep]

        # maskness.
        seg_scores = (seg_preds * seg_masks.float()).sum((1, 2)) / sum_masks
        cate_scores *= seg_scores

        # sort and keep top nms_pre
        sort_inds = torch.argsort(cate_scores, descending=True)
        if len(sort_inds) > self.max_before_nms:
            sort_inds = sort_inds[:self.max_before_nms]
        seg_masks = seg_masks[sort_inds, :, :]
        seg_preds = seg_preds[sort_inds, :, :]
        sum_masks = sum_masks[sort_inds]
        cate_scores = cate_scores[sort_inds]
        cate_labels = cate_labels[sort_inds]
        points = points[sort_inds]

        if self.nms_type == "matrix":
            # matrix nms & filter.
            cate_scores = matrix_nms(cate_labels, seg_masks, sum_masks, cate_scores,
                                     sigma=self.nms_sigma, kernel=self.nms_kernel)
            keep = cate_scores >= self.update_threshold
        elif self.nms_type == "mask":
            # original mask nms.
            keep = mask_nms(cate_labels, seg_masks, sum_masks, cate_scores,
                            nms_thr=self.mask_threshold)
        else:
            raise NotImplementedError

        if keep.sum() == 0:
            return self._empty_instances(ori_size)

        seg_preds = seg_preds[keep, :, :]
        cate_scores = cate_scores[keep]
        cate_labels = cate_labels[keep]
        points = points[keep]

        # sort and keep top_k
        sort_inds = torch.argsort(cate_scores, descending=True)
        if len(sort_inds) > self.max_per_img:
            sort_inds = sort_inds[:self.max_per_img]
        seg_preds = seg_preds[sort_inds, :, :]
        cate_scores = cate_scores[sort_inds]
        cate_labels = cate_labels[sort_inds]
        points = points[sort_inds]

        # reshape to original size: upsample, crop the padding, resize.
        seg_preds = F.interpolate(seg_preds.unsqueeze(0),
                                  size=upsampled_size_out,
                                  mode='bilinear')[:, :, :h, :w]
        seg_masks = F.interpolate(seg_preds,
                                  size=ori_size,
                                  mode='bilinear').squeeze(0)
        seg_masks = seg_masks > self.mask_threshold

        if self.eval_pseudo_label:
            # set all confidence scores to 1.0
            cate_scores = torch.ones_like(cate_scores)

        results = Instances(ori_size)
        results.pred_classes = cate_labels
        results.scores = cate_scores
        results.pred_masks = seg_masks

        # get bbox from mask (rows stay zero when nothing is written)
        pred_boxes = torch.zeros(seg_masks.size(0), 4)
        for i in range(seg_masks.size(0)):
            if self.prompt == "none":
                # seg_masks[i] is [H, W]; no squeeze() so 1-pixel-wide outputs
                # still unpack into (ys, xs).
                ys, xs = torch.where(seg_masks[i])
                if xs.numel() > 0:
                    pred_boxes[i] = torch.tensor([xs.min(), ys.min(), xs.max(), ys.max()]).float()
            else:
                # saving point coordinate for each segment output
                cy = points[i] // PROJECTION
                cx = points[i] % PROJECTION
                pred_boxes[i] = torch.tensor([cy, cx, 0, 0]).float()
        results.pred_boxes = Boxes(pred_boxes)
        return results
class SOLOv2InsHead(nn.Module):
    """
    SOLOv2 instance head.

    For each input level the feature map (with two normalized coordinate
    channels appended) is resized to an S x S grid, S = NUM_GRIDS[level], and
    fed to two towers: the category tower predicts per-cell class logits and
    the kernel tower predicts per-cell dynamic kernels.
    """

    def __init__(self, cfg, input_shape: "List[ShapeSpec]"):
        """
        SOLOv2 Instance Head.

        Args:
            cfg: config; reads ``MODEL.SOLOV2.*``.
            input_shape: one spec per input level; only ``.channels`` is read,
                and all levels must have INSTANCE_IN_CHANNELS channels.
        """
        super().__init__()
        # fmt: off
        self.num_classes = cfg.MODEL.SOLOV2.NUM_CLASSES
        self.num_kernels = cfg.MODEL.SOLOV2.NUM_KERNELS
        self.num_grids = cfg.MODEL.SOLOV2.NUM_GRIDS
        self.instance_in_features = cfg.MODEL.SOLOV2.INSTANCE_IN_FEATURES
        self.instance_strides = cfg.MODEL.SOLOV2.FPN_INSTANCE_STRIDES
        self.instance_in_channels = cfg.MODEL.SOLOV2.INSTANCE_IN_CHANNELS  # = fpn.
        self.instance_channels = cfg.MODEL.SOLOV2.INSTANCE_CHANNELS
        # Whether the kernel tower consumes the two coordinate channels.
        self.use_coord_conv = cfg.MODEL.SOLOV2.USE_COORD_CONV
        # Convolutions to use in the towers
        # NOTE(review): TYPE_DCN / USE_DCN_IN_INSTANCE are read but every tower
        # conv is a plain nn.Conv2d; deformable convs are not implemented here.
        self.type_dcn = cfg.MODEL.SOLOV2.TYPE_DCN
        self.num_levels = len(self.instance_in_features)
        assert self.num_levels == len(self.instance_strides), \
            "Strides should match the features."
        # fmt: on

        # head name -> (num convs, use deformable (unused), use coord channels)
        head_configs = {"cate": (cfg.MODEL.SOLOV2.NUM_INSTANCE_CONVS,
                                 cfg.MODEL.SOLOV2.USE_DCN_IN_INSTANCE,
                                 False),
                        "kernel": (cfg.MODEL.SOLOV2.NUM_INSTANCE_CONVS,
                                   cfg.MODEL.SOLOV2.USE_DCN_IN_INSTANCE,
                                   cfg.MODEL.SOLOV2.USE_COORD_CONV)
                        }

        norm = None if cfg.MODEL.SOLOV2.NORM == "none" else cfg.MODEL.SOLOV2.NORM
        in_channels = [s.channels for s in input_shape]
        assert len(set(in_channels)) == 1, \
            "Each level must have the same channel!"
        in_channels = in_channels[0]
        assert in_channels == cfg.MODEL.SOLOV2.INSTANCE_IN_CHANNELS, \
            "In channels should equal to tower in channels!"

        # Towers are built in dict order (cate, kernel); keep it, since module
        # construction consumes the global RNG.
        for head in head_configs:
            tower = []
            num_convs, _use_deformable, use_coord = head_configs[head]
            for i in range(num_convs):
                if i == 0:
                    chn = self.instance_in_channels + 2 if use_coord else self.instance_in_channels
                else:
                    chn = self.instance_channels
                tower.append(nn.Conv2d(
                    chn, self.instance_channels,
                    kernel_size=3, stride=1,
                    padding=1, bias=norm is None
                ))
                if norm == "GN":
                    tower.append(nn.GroupNorm(32, self.instance_channels))
                tower.append(nn.ReLU(inplace=True))
            self.add_module('{}_tower'.format(head),
                            nn.Sequential(*tower))

        self.cate_pred = nn.Conv2d(
            self.instance_channels, self.num_classes,
            kernel_size=3, stride=1, padding=1
        )
        self.kernel_pred = nn.Conv2d(
            self.instance_channels, self.num_kernels,
            kernel_size=3, stride=1, padding=1
        )

        for modules in [
            self.cate_tower, self.kernel_tower,
            self.cate_pred, self.kernel_pred,
        ]:
            for layer in modules.modules():
                if isinstance(layer, nn.Conv2d):
                    torch.nn.init.normal_(layer.weight, std=0.01)
                    if layer.bias is not None:
                        nn.init.constant_(layer.bias, 0)

        # initialize the bias for focal loss
        prior_prob = cfg.MODEL.SOLOV2.PRIOR_PROB
        bias_value = -math.log((1 - prior_prob) / prior_prob)
        torch.nn.init.constant_(self.cate_pred.bias, bias_value)

    def forward(self, features):
        """
        Arguments:
            features (list[Tensor]): FPN feature map tensors in high to low resolution.
                Each tensor in the list correspond to different feature levels.

        Returns:
            cate_pred (list[Tensor]): per level, [B, num_classes, S, S] logits.
            kernel_pred (list[Tensor]): per level, [B, num_kernels, S, S] kernels.
        """
        cate_pred = []
        kernel_pred = []

        for idx, feature in enumerate(features):
            # concat coord: normalized x, y in [-1, 1] as two extra channels.
            x_range = torch.linspace(-1, 1, feature.shape[-1], device=feature.device)
            y_range = torch.linspace(-1, 1, feature.shape[-2], device=feature.device)
            y, x = torch.meshgrid(y_range, x_range)
            y = y.expand([feature.shape[0], 1, -1, -1])
            x = x.expand([feature.shape[0], 1, -1, -1])
            coord_feat = torch.cat([x, y], 1)
            ins_kernel_feat = torch.cat([feature, coord_feat], 1)

            # individual feature, resized to this level's S x S grid.
            seg_num_grid = self.num_grids[idx]
            kernel_feat = F.interpolate(ins_kernel_feat, size=seg_num_grid, mode='bilinear')
            # the category tower never sees the coordinate channels
            cate_feat = kernel_feat[:, :-2, :, :]
            if not self.use_coord_conv:
                # the kernel tower was built without the coordinate channels
                kernel_feat = cate_feat

            # kernel
            kernel_feat = self.kernel_tower(kernel_feat)
            kernel_pred.append(self.kernel_pred(kernel_feat))

            # cate
            cate_feat = self.cate_tower(cate_feat)
            cate_pred.append(self.cate_pred(cate_feat))
        return cate_pred, kernel_pred
class SOLOv2MaskHead(nn.Module):
    """
    SOLOv2 mask head.

    Level i goes through i (conv -> x2 upsample) stages (level 0: one conv, no
    upsampling); the results are summed and projected to ``num_masks``
    channels. Two coordinate channels are appended to level 3 only.
    """

    def __init__(self, cfg, input_shape: "List[ShapeSpec]"):
        """
        SOLOv2 Mask Head.

        Args:
            cfg: config; reads ``MODEL.MASK_ON`` and ``MODEL.SOLOV2.*``.
            input_shape: one spec per input level; only its length is used and
                it must match ``MASK_IN_FEATURES``.
        """
        super().__init__()
        # fmt: off
        self.mask_on = cfg.MODEL.MASK_ON
        self.num_masks = cfg.MODEL.SOLOV2.NUM_MASKS
        self.mask_in_features = cfg.MODEL.SOLOV2.MASK_IN_FEATURES
        self.mask_in_channels = cfg.MODEL.SOLOV2.MASK_IN_CHANNELS
        self.mask_channels = cfg.MODEL.SOLOV2.MASK_CHANNELS
        self.num_levels = len(input_shape)
        assert self.num_levels == len(self.mask_in_features), \
            "Input shape should match the features."
        # fmt: on
        norm = None if cfg.MODEL.SOLOV2.NORM == "none" else cfg.MODEL.SOLOV2.NORM

        def conv_block(in_channels):
            # 3x3 conv (+ optional GroupNorm) + ReLU; submodule indices 0/1/2
            # match the original layout so checkpoints keep loading.
            layers = [nn.Conv2d(
                in_channels, self.mask_channels,
                kernel_size=3, stride=1,
                padding=1, bias=norm is None
            )]
            if norm == "GN":
                layers.append(nn.GroupNorm(32, self.mask_channels))
            layers.append(nn.ReLU(inplace=False))
            return nn.Sequential(*layers)

        self.convs_all_levels = nn.ModuleList()
        for i in range(self.num_levels):
            convs_per_level = nn.Sequential()
            if i == 0:
                convs_per_level.add_module('conv' + str(i), conv_block(self.mask_in_channels))
                self.convs_all_levels.append(convs_per_level)
                continue

            for j in range(i):
                if j == 0:
                    # level 3 additionally receives the 2 coordinate channels
                    chn = self.mask_in_channels + 2 if i == 3 else self.mask_in_channels
                else:
                    chn = self.mask_channels
                convs_per_level.add_module('conv' + str(j), conv_block(chn))
                upsample_tower = nn.Upsample(
                    scale_factor=2, mode='bilinear', align_corners=False)
                convs_per_level.add_module('upsample' + str(j), upsample_tower)
            self.convs_all_levels.append(convs_per_level)

        # NOTE(review): GroupNorm(32, num_masks) is applied regardless of NORM,
        # so num_masks must be divisible by 32.
        self.conv_pred = nn.Sequential(
            nn.Conv2d(
                self.mask_channels, self.num_masks,
                kernel_size=1, stride=1,
                padding=0, bias=norm is None),
            nn.GroupNorm(32, self.num_masks),
            nn.ReLU(inplace=True)
        )

        for modules in [self.convs_all_levels, self.conv_pred]:
            for layer in modules.modules():
                if isinstance(layer, nn.Conv2d):
                    torch.nn.init.normal_(layer.weight, std=0.01)
                    if layer.bias is not None:
                        nn.init.constant_(layer.bias, 0)

    def forward(self, features):
        """
        Arguments:
            features (list[Tensor]): FPN feature map tensors in high to low resolution.
                Each tensor in the list correspond to different feature levels.
                Level i is upsampled i times by 2 before summation, so
                features[i] is expected to be 2**i times smaller than features[0].

        Returns:
            Tensor: [B, num_masks, H0, W0] unified mask features, where
            (H0, W0) is the spatial size of ``features[0]``.
        """
        assert len(features) == self.num_levels, \
            "The number of input features should be equal to the supposed level."

        # bottom features first.
        feature_add_all_level = self.convs_all_levels[0](features[0])
        for i in range(1, self.num_levels):
            mask_feat = features[i]
            if i == 3:  # add for coord.
                x_range = torch.linspace(-1, 1, mask_feat.shape[-1], device=mask_feat.device)
                y_range = torch.linspace(-1, 1, mask_feat.shape[-2], device=mask_feat.device)
                y, x = torch.meshgrid(y_range, x_range)
                y = y.expand([mask_feat.shape[0], 1, -1, -1])
                x = x.expand([mask_feat.shape[0], 1, -1, -1])
                coord_feat = torch.cat([x, y], 1)
                mask_feat = torch.cat([mask_feat, coord_feat], 1)
            # add for top features.
            feature_add_all_level = feature_add_all_level + self.convs_all_levels[i](mask_feat)

        mask_pred = self.conv_pred(feature_add_all_level)
        return mask_pred