
[add] DETR Instance Segmentation via Detectron2 (#165)
* [add] DETR Instance Segmentation via Detectron2

* [change] Following @fmassa's comments

* [fix] typo

* Update README.md

Add segmentation training script.

* [delete] import detectron2

* [refactor]

* Update d2/detr/detr.py

Co-authored-by: alcinos <nicolas.carion@ens-lyon.fr>
jd730 and alcinos authored Aug 11, 2020
1 parent 48cf0a7 commit d45a12c
Showing 7 changed files with 103 additions and 16 deletions.
5 changes: 5 additions & 0 deletions d2/README.md
@@ -32,3 +32,8 @@ To train DETR on a single node with 8 gpus, simply use:
 ```
 python train_net.py --config configs/detr_256_6_6_torchvision.yaml --num-gpus 8
 ```
+
+To fine-tune DETR for instance segmentation on a single node with 8 gpus, simply use:
+```
+python train_net.py --config configs/detr_segm_256_6_6_torchvision.yaml --num-gpus 8 MODEL.DETR.FROZEN_WEIGHTS <model_path>
+```
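
Note: `MODEL.DETR.FROZEN_WEIGHTS` should point at a checkpoint from a detection-only DETR run; the segmentation model freezes those weights and trains only the mask head (see `d2/detr/detr.py` below). A typical two-step sequence might look like the following — the `OUTPUT_DIR` override and the `model_final.pth` filename follow detectron2's defaults, and the paths themselves are illustrative:

```
python train_net.py --config configs/detr_256_6_6_torchvision.yaml --num-gpus 8 OUTPUT_DIR output/detr_box
python train_net.py --config configs/detr_segm_256_6_6_torchvision.yaml --num-gpus 8 MODEL.DETR.FROZEN_WEIGHTS output/detr_box/model_final.pth
```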
46 changes: 46 additions & 0 deletions d2/configs/detr_segm_256_6_6_torchvision.yaml
@@ -0,0 +1,46 @@
+MODEL:
+  META_ARCHITECTURE: "Detr"
+  # WEIGHTS: "detectron2://ImageNetPretrained/torchvision/R-50.pkl"
+  PIXEL_MEAN: [123.675, 116.280, 103.530]
+  PIXEL_STD: [58.395, 57.120, 57.375]
+  MASK_ON: True
+  RESNETS:
+    DEPTH: 50
+    STRIDE_IN_1X1: False
+    OUT_FEATURES: ["res2", "res3", "res4", "res5"]
+  DETR:
+    GIOU_WEIGHT: 2.0
+    L1_WEIGHT: 5.0
+    NUM_OBJECT_QUERIES: 100
+    FROZEN_WEIGHTS: ''
+DATASETS:
+  TRAIN: ("coco_2017_train",)
+  TEST: ("coco_2017_val",)
+SOLVER:
+  IMS_PER_BATCH: 64
+  BASE_LR: 0.0001
+  STEPS: (55440,)
+  MAX_ITER: 92400
+  WARMUP_FACTOR: 1.0
+  WARMUP_ITERS: 10
+  WEIGHT_DECAY: 0.0001
+  OPTIMIZER: "ADAMW"
+  BACKBONE_MULTIPLIER: 0.1
+  CLIP_GRADIENTS:
+    ENABLED: True
+    CLIP_TYPE: "full_model"
+    CLIP_VALUE: 0.01
+    NORM_TYPE: 2.0
+INPUT:
+  MIN_SIZE_TRAIN: (480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800)
+  CROP:
+    ENABLED: True
+    TYPE: "absolute_range"
+    SIZE: (384, 600)
+  FORMAT: "RGB"
+TEST:
+  EVAL_PERIOD: 4000
+DATALOADER:
+  FILTER_EMPTY_ANNOTATIONS: False
+  NUM_WORKERS: 4
+VERSION: 2
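
As a sanity check on the solver settings above: with `IMS_PER_BATCH: 64`, the iteration counts translate to roughly a 50-epoch run with the learning-rate drop around epoch 30 (back-of-the-envelope, assuming coco_2017_train's ~118,287 images):

```
# Rough schedule math for the config above (not part of the commit).
ims_per_batch = 64          # SOLVER.IMS_PER_BATCH
coco_train_size = 118_287   # approximate size of coco_2017_train
print(92_400 * ims_per_batch / coco_train_size)  # MAX_ITER -> ~50 epochs
print(55_440 * ims_per_batch / coco_train_size)  # STEPS[0] -> ~30 epochs, where the LR drops
```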
3 changes: 3 additions & 0 deletions d2/detr/config.py
@@ -10,6 +10,9 @@ def add_detr_config(cfg):
     cfg.MODEL.DETR = CN()
     cfg.MODEL.DETR.NUM_CLASSES = 80
 
+    # For Segmentation
+    cfg.MODEL.DETR.FROZEN_WEIGHTS = ''
+
     # LOSS
     cfg.MODEL.DETR.GIOU_WEIGHT = 2.0
     cfg.MODEL.DETR.L1_WEIGHT = 5.0
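This default matters because detectron2 configs reject unknown keys when merging. A minimal sketch of the intended flow (the import path assumes the script runs from the `d2/` directory, as `train_net.py` does; the checkpoint path is illustrative):

```
from detectron2.config import get_cfg
from detr import add_detr_config  # d2/detr/config.py

cfg = get_cfg()
add_detr_config(cfg)  # without this, merge_from_file raises on MODEL.DETR.FROZEN_WEIGHTS
cfg.merge_from_file("configs/detr_segm_256_6_6_torchvision.yaml")
cfg.merge_from_list(["MODEL.DETR.FROZEN_WEIGHTS", "output/detr_box/model_final.pth"])
```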
6 changes: 3 additions & 3 deletions d2/detr/dataset_mapper.py
@@ -61,8 +61,7 @@ def __init__(self, cfg, is_train=True):
         else:
             self.crop_gen = None
 
-        assert not cfg.MODEL.MASK_ON, "Mask is not supported"
-
+        self.mask_on = cfg.MODEL.MASK_ON
         self.tfm_gens = build_transform_gen(cfg, is_train)
         logging.getLogger(__name__).info(
             "Full TransformGens used in training: {}, crop: {}".format(str(self.tfm_gens), str(self.crop_gen))
@@ -108,7 +107,8 @@ def __call__(self, dataset_dict):
         if "annotations" in dataset_dict:
             # USER: Modify this if you want to keep them for some reason.
             for anno in dataset_dict["annotations"]:
-                anno.pop("segmentation", None)
+                if not self.mask_on:
+                    anno.pop("segmentation", None)
                 anno.pop("keypoints", None)
 
         # USER: Implement additional transformations if you have other types of data
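Net effect of the mapper change: with `MASK_ON: True`, COCO `segmentation` polygons now survive annotation filtering and reach the ground-truth instances, while keypoints are still dropped unconditionally. An illustrative toy annotation (field values are made up):

```
anno = {"bbox": [10, 20, 50, 80], "segmentation": [[10, 20, 60, 20, 60, 100]], "keypoints": []}
mask_on = True  # cfg.MODEL.MASK_ON
if not mask_on:
    anno.pop("segmentation", None)  # old behavior: polygons were always removed
anno.pop("keypoints", None)         # keypoints remain unsupported either way
assert "segmentation" in anno       # polygons now flow into gt_masks downstream
```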
45 changes: 39 additions & 6 deletions d2/detr/detr.py
@@ -12,16 +12,18 @@
 
 from detectron2.layers import ShapeSpec
 from detectron2.modeling import META_ARCH_REGISTRY, build_backbone, detector_postprocess
-from detectron2.structures import Boxes, ImageList, Instances
+from detectron2.structures import Boxes, ImageList, Instances, BitMasks, PolygonMasks
 from detectron2.utils.logger import log_first_n
 from fvcore.nn import giou_loss, smooth_l1_loss
 from models.backbone import Joiner
 from models.detr import DETR, SetCriterion
 from models.matcher import HungarianMatcher
 from models.position_encoding import PositionEmbeddingSine
 from models.transformer import Transformer
+from models.segmentation import DETRsegm, PostProcessPanoptic, PostProcessSegm
 from util.box_ops import box_cxcywh_to_xyxy, box_xyxy_to_cxcywh
 from util.misc import NestedTensor
+from datasets.coco import convert_coco_poly_to_mask
 
 __all__ = ["Detr"]
 
@@ -76,6 +78,7 @@ def __init__(self, cfg):
         self.device = torch.device(cfg.MODEL.DEVICE)
 
         self.num_classes = cfg.MODEL.DETR.NUM_CLASSES
+        self.mask_on = cfg.MODEL.MASK_ON
         hidden_dim = cfg.MODEL.DETR.HIDDEN_DIM
         num_queries = cfg.MODEL.DETR.NUM_OBJECT_QUERIES
         # Transformer parameters:
@@ -111,6 +114,23 @@ def __init__(self, cfg):
         self.detr = DETR(
             backbone, transformer, num_classes=self.num_classes, num_queries=num_queries, aux_loss=deep_supervision
         )
+        if self.mask_on:
+            frozen_weights = cfg.MODEL.DETR.FROZEN_WEIGHTS
+            if frozen_weights != '':
+                print("LOAD pre-trained weights")
+                weight = torch.load(frozen_weights, map_location=lambda storage, loc: storage)['model']
+                new_weight = {}
+                for k, v in weight.items():
+                    if 'detr.' in k:
+                        new_weight[k.replace('detr.', '')] = v
+                    else:
+                        print(f"Skipping loading weight {k} from frozen model")
+                del weight
+                self.detr.load_state_dict(new_weight)
+                del new_weight
+            self.detr = DETRsegm(self.detr, freeze_detr=(frozen_weights != ''))
+            self.seg_postprocess = PostProcessSegm
+
         self.detr.to(self.device)
 
         # building criterion
@@ -123,8 +143,10 @@ def __init__(self, cfg):
             aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()})
         weight_dict.update(aux_weight_dict)
         losses = ["labels", "boxes", "cardinality"]
+        if self.mask_on:
+            losses += ["masks"]
         self.criterion = SetCriterion(
-            self.num_classes, matcher=matcher, weight_dict=weight_dict, eos_coef=no_object_weight, losses=losses
+            self.num_classes, matcher=matcher, weight_dict=weight_dict, eos_coef=no_object_weight, losses=losses,
         )
         self.criterion.to(self.device)
 
@@ -167,7 +189,8 @@ def forward(self, batched_inputs):
         else:
             box_cls = output["pred_logits"]
             box_pred = output["pred_boxes"]
-            results = self.inference(box_cls, box_pred, images.image_sizes)
+            mask_pred = output["pred_masks"] if self.mask_on else None
+            results = self.inference(box_cls, box_pred, mask_pred, images.image_sizes)
             processed_results = []
             for results_per_image, input_per_image, image_size in zip(results, batched_inputs, images.image_sizes):
                 height = input_per_image.get("height", image_size[0])
@@ -185,9 +208,13 @@ def prepare_targets(self, targets):
             gt_boxes = targets_per_image.gt_boxes.tensor / image_size_xyxy
             gt_boxes = box_xyxy_to_cxcywh(gt_boxes)
             new_targets.append({"labels": gt_classes, "boxes": gt_boxes})
+            if self.mask_on and hasattr(targets_per_image, 'gt_masks'):
+                gt_masks = targets_per_image.gt_masks
+                gt_masks = convert_coco_poly_to_mask(gt_masks.polygons, h, w)
+                new_targets[-1].update({'masks': gt_masks})
         return new_targets
 
-    def inference(self, box_cls, box_pred, image_sizes):
+    def inference(self, box_cls, box_pred, mask_pred, image_sizes):
         """
         Arguments:
             box_cls (Tensor): tensor of shape (batch_size, num_queries, K).
@@ -206,13 +233,19 @@ def inference(self, box_cls, box_pred, image_sizes):
         # For each box we assign the best class or the second best if the best on is `no_object`.
         scores, labels = F.softmax(box_cls, dim=-1)[:, :, :-1].max(-1)
 
-        for scores_per_image, labels_per_image, box_pred_per_image, image_size in zip(
+        for i, (scores_per_image, labels_per_image, box_pred_per_image, image_size) in enumerate(zip(
             scores, labels, box_pred, image_sizes
-        ):
+        )):
             result = Instances(image_size)
             result.pred_boxes = Boxes(box_cxcywh_to_xyxy(box_pred_per_image))
 
             result.pred_boxes.scale(scale_x=image_size[1], scale_y=image_size[0])
+            if self.mask_on:
+                mask = F.interpolate(mask_pred[i].unsqueeze(0), size=image_size, mode='bilinear', align_corners=False)
+                mask = mask[0].sigmoid() > 0.5
+                B, N, H, W = mask_pred.shape
+                mask = BitMasks(mask.cpu()).crop_and_resize(result.pred_boxes.tensor.cpu(), 32)
+                result.pred_masks = mask.unsqueeze(1).to(mask_pred[0].device)
 
             result.scores = scores_per_image
             result.pred_classes = labels_per_image
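
A note on the mask branch in `inference`: detectron2's `detector_postprocess` expects `pred_masks` in per-box ROI format, `(num_boxes, 1, M, M)`, and pastes each mask back into the image using its predicted box. DETR instead predicts full-image mask logits, so the code upsamples them to the padded image size, binarizes at 0.5, and converts every full-image mask into a fixed 32x32 per-box crop. A standalone sketch of that conversion (all shapes and boxes are illustrative):

```
import torch
from detectron2.structures import BitMasks

num_queries, H, W = 100, 800, 1216                  # padded image size
full_masks = torch.rand(num_queries, H, W) > 0.5    # binarized full-image masks
boxes = torch.tensor([[16.0, 16.0, 200.0, 240.0]]).repeat(num_queries, 1)  # XYXY, one per query
roi_masks = BitMasks(full_masks).crop_and_resize(boxes, 32)  # (100, 32, 32), bool
pred_masks = roi_masks.unsqueeze(1)                 # (100, 1, 32, 32), what detector_postprocess expects
```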
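And on the `FROZEN_WEIGHTS` loading above: a checkpoint saved by this d2 wrapper stores the plain DETR model under the `detr.` key prefix (the meta-architecture holds it as `self.detr`), so the loop strips that prefix before `load_state_dict`; `DETRsegm(..., freeze_detr=True)` then freezes the detector so only the mask head trains. A condensed, hedged equivalent of the remapping (the helper name is mine, not the repo's):

```
import torch

def strip_detr_prefix(ckpt_path):
    weights = torch.load(ckpt_path, map_location="cpu")["model"]
    # keep only the keys that belong to the wrapped DETR detector
    return {k.replace("detr.", ""): v for k, v in weights.items() if "detr." in k}
```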
12 changes: 6 additions & 6 deletions models/detr.py
@@ -169,21 +169,21 @@ def loss_masks(self, outputs, targets, indices, num_boxes):
 
         src_idx = self._get_src_permutation_idx(indices)
         tgt_idx = self._get_tgt_permutation_idx(indices)
-
         src_masks = outputs["pred_masks"]
-
+        src_masks = src_masks[src_idx]
+        masks = [t["masks"] for t in targets]
         # TODO use valid to mask invalid areas due to padding in loss
-        target_masks, valid = nested_tensor_from_tensor_list([t["masks"] for t in targets]).decompose()
+        target_masks, valid = nested_tensor_from_tensor_list(masks).decompose()
         target_masks = target_masks.to(src_masks)
+        target_masks = target_masks[tgt_idx]
 
-        src_masks = src_masks[src_idx]
         # upsample predictions to the target size
         src_masks = interpolate(src_masks[:, None], size=target_masks.shape[-2:],
                                 mode="bilinear", align_corners=False)
         src_masks = src_masks[:, 0].flatten(1)
 
-        target_masks = target_masks[tgt_idx].flatten(1)
-
+        target_masks = target_masks.flatten(1)
+        target_masks = target_masks.view(src_masks.shape)
         losses = {
             "loss_mask": sigmoid_focal_loss(src_masks, target_masks, num_boxes),
             "loss_dice": dice_loss(src_masks, target_masks, num_boxes),
2 changes: 1 addition & 1 deletion models/segmentation.py
@@ -35,7 +35,7 @@ def __init__(self, detr, freeze_detr=False):
         self.mask_head = MaskHeadSmallConv(hidden_dim + nheads, [1024, 512, 256], hidden_dim)
 
     def forward(self, samples: NestedTensor):
-        if not isinstance(samples, NestedTensor):
+        if isinstance(samples, (list, torch.Tensor)):
             samples = nested_tensor_from_tensor_list(samples)
         features, pos = self.detr.backbone(samples)
 
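The flipped condition mirrors the check in `DETR.forward`: only raw inputs (a list of images or a batched tensor) get wrapped, and any other input type passes through untouched. A hedged usage sketch (image sizes are illustrative):

```
import torch
from util.misc import NestedTensor, nested_tensor_from_tensor_list

samples = [torch.rand(3, 480, 640), torch.rand(3, 512, 512)]  # variable-size images
if isinstance(samples, (list, torch.Tensor)):
    samples = nested_tensor_from_tensor_list(samples)  # pads to a common size and builds the mask
assert isinstance(samples, NestedTensor)
```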
