This repository has been archived by the owner on Mar 12, 2024. It is now read-only.

Commit bc3c7e3

Merge pull request #54 from facebookresearch/hubconf_panoptic

Hub config for panoptic models

alcinos authored Jun 4, 2020
2 parents a5cd934 + ee75d89
Showing 4 changed files with 153 additions and 59 deletions.
9 changes: 8 additions & 1 deletion engine.py
@@ -112,8 +112,15 @@ def evaluate(model, criterion, postprocessors, data_loader, base_ds, device, out
         res = {target['image_id'].item(): output for target, output in zip(targets, results)}
         if coco_evaluator is not None:
             coco_evaluator.update(res)

         if panoptic_evaluator is not None:
-            res_pano = postprocessors["panoptic"](outputs, targets)
+            res_pano = postprocessors["panoptic"](outputs, target_sizes, orig_target_sizes)
+            for i, target in enumerate(targets):
+                image_id = target["image_id"].item()
+                file_name = f"{image_id:012d}.png"
+                res_pano[i]["image_id"] = image_id
+                res_pano[i]["file_name"] = file_name
+
             panoptic_evaluator.update(res_pano)

     # gather the stats from all processes
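Note: since the panoptic postprocessor no longer receives the targets, evaluate() now stamps "image_id" and "file_name" onto each prediction itself. For illustration, the shape of one res_pano entry after the loop above (field values are made up; "png_string" and "segments_info" come from PostProcessPanoptic in the format the COCO panoptic API expects):

    # Illustrative sketch only -- values are invented, the field layout follows the diff above.
    res_pano_entry = {
        "png_string": b"\x89PNG...",  # PNG-encoded segmentation map
        "segments_info": [
            {"id": 0, "isthing": True, "category_id": 17, "area": 10242},
            {"id": 1, "isthing": False, "category_id": 118, "area": 52310},
        ],
        "image_id": 42,                   # stamped on by evaluate()
        "file_name": "000000000042.png",  # f"{image_id:012d}.png", as above
    }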
144 changes: 113 additions & 31 deletions hubconf.py
@@ -1,86 +1,168 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
 import torch
-from models.detr import DETR
+
 from models.backbone import Backbone, Joiner
-from models.transformer import Transformer
+from models.detr import DETR, PostProcess
 from models.position_encoding import PositionEmbeddingSine
-
-dependencies = ['torch', 'torchvision']
+from models.segmentation import DETRsegm, PostProcessPanoptic
+from models.transformer import Transformer
+
+dependencies = ["torch", "torchvision"]


-def _make_detr(backbone_name: str, dilation=False, num_classes=91):
+def _make_detr(backbone_name: str, dilation=False, num_classes=91, mask=False):
     hidden_dim = 256
-    backbone = Backbone(backbone_name, train_backbone=True,
-                        return_interm_layers=False, dilation=dilation)
+    backbone = Backbone(backbone_name, train_backbone=True, return_interm_layers=False, dilation=dilation)
     pos_enc = PositionEmbeddingSine(hidden_dim // 2, normalize=True)
     backbone_with_pos_enc = Joiner(backbone, pos_enc)
     backbone_with_pos_enc.num_channels = backbone.num_channels
     transformer = Transformer(d_model=hidden_dim, return_intermediate_dec=True)
-    return DETR(backbone_with_pos_enc, transformer, num_classes=num_classes, num_queries=100)
+    detr = DETR(backbone_with_pos_enc, transformer, num_classes=num_classes, num_queries=100)
+    if mask:
+        return DETRsegm(detr)
+    return detr


-def detr_resnet50(pretrained=False, num_classes=91):
+def detr_resnet50(pretrained=False, num_classes=91, return_postprocessor=False):
     """
     DETR R50 with 6 encoder and 6 decoder layers.
     Achieves 42/62.4 AP/AP50 on COCO val5k.
     """
-    model = _make_detr('resnet50', dilation=False, num_classes=num_classes)
+    model = _make_detr("resnet50", dilation=False, num_classes=num_classes)
     if pretrained:
         checkpoint = torch.hub.load_state_dict_from_url(
-            url='https://dl.fbaipublicfiles.com/detr/detr-r50-e632da11.pth',
-            map_location='cpu',
-            check_hash=True)
-        model.load_state_dict(checkpoint['model'])
+            url="https://dl.fbaipublicfiles.com/detr/detr-r50-e632da11.pth", map_location="cpu", check_hash=True
+        )
+        model.load_state_dict(checkpoint["model"])
+    if return_postprocessor:
+        return model, PostProcess()
     return model


-def detr_resnet50_dc5(pretrained=False, num_classes=91):
+def detr_resnet50_dc5(pretrained=False, num_classes=91, return_postprocessor=False):
     """
     DETR-DC5 R50 with 6 encoder and 6 decoder layers.
     The last block of ResNet-50 has dilation to increase
     output resolution.
     Achieves 43.3/63.1 AP/AP50 on COCO val5k.
     """
-    model = _make_detr('resnet50', dilation=True, num_classes=num_classes)
+    model = _make_detr("resnet50", dilation=True, num_classes=num_classes)
     if pretrained:
         checkpoint = torch.hub.load_state_dict_from_url(
-            url='https://dl.fbaipublicfiles.com/detr/detr-r50-dc5-f0fb7ef5.pth',
-            map_location='cpu',
-            check_hash=True)
-        model.load_state_dict(checkpoint['model'])
+            url="https://dl.fbaipublicfiles.com/detr/detr-r50-dc5-f0fb7ef5.pth", map_location="cpu", check_hash=True
+        )
+        model.load_state_dict(checkpoint["model"])
+    if return_postprocessor:
+        return model, PostProcess()
     return model


-def detr_resnet101(pretrained=False, num_classes=91):
+def detr_resnet101(pretrained=False, num_classes=91, return_postprocessor=False):
     """
     DETR R101 with 6 encoder and 6 decoder layers.
     Achieves 43.5/63.8 AP/AP50 on COCO val5k.
     """
-    model = _make_detr('resnet101', dilation=False, num_classes=num_classes)
+    model = _make_detr("resnet101", dilation=False, num_classes=num_classes)
     if pretrained:
         checkpoint = torch.hub.load_state_dict_from_url(
-            url='https://dl.fbaipublicfiles.com/detr/detr-r101-2c7b67e5.pth',
-            map_location='cpu',
-            check_hash=True)
-        model.load_state_dict(checkpoint['model'])
+            url="https://dl.fbaipublicfiles.com/detr/detr-r101-2c7b67e5.pth", map_location="cpu", check_hash=True
+        )
+        model.load_state_dict(checkpoint["model"])
+    if return_postprocessor:
+        return model, PostProcess()
     return model


-def detr_resnet101_dc5(pretrained=False, num_classes=91):
+def detr_resnet101_dc5(pretrained=False, num_classes=91, return_postprocessor=False):
     """
     DETR-DC5 R101 with 6 encoder and 6 decoder layers.
     The last block of ResNet-101 has dilation to increase
     output resolution.
     Achieves 44.9/64.7 AP/AP50 on COCO val5k.
     """
-    model = _make_detr('resnet101', dilation=True, num_classes=num_classes)
+    model = _make_detr("resnet101", dilation=True, num_classes=num_classes)
     if pretrained:
         checkpoint = torch.hub.load_state_dict_from_url(
+            url="https://dl.fbaipublicfiles.com/detr/detr-r101-dc5-a2e86def.pth", map_location="cpu", check_hash=True
+        )
+        model.load_state_dict(checkpoint["model"])
+    if return_postprocessor:
+        return model, PostProcess()
+    return model


+def detr_resnet50_panoptic(
+    pretrained=False, num_classes=250, threshold=0.85, return_postprocessor=False
+):
+    """
+    DETR R50 with 6 encoder and 6 decoder layers.
+    Achieves 43.4 PQ on COCO val5k.
+    threshold is the minimum confidence required for keeping segments in the prediction
+    """
+    model = _make_detr("resnet50", dilation=False, num_classes=num_classes, mask=True)
+    is_thing_map = {i: i <= 90 for i in range(250)}
+    if pretrained:
+        checkpoint = torch.hub.load_state_dict_from_url(
+            url="https://dl.fbaipublicfiles.com/detr/detr-r50-panoptic-00ce5173.pth",
+            map_location="cpu",
+            check_hash=True,
+        )
+        model.load_state_dict(checkpoint["model"])
+    if return_postprocessor:
+        return model, PostProcessPanoptic(is_thing_map, threshold=threshold)
+    return model


+def detr_resnet50_dc5_panoptic(
+    pretrained=False, num_classes=250, threshold=0.85, return_postprocessor=False
+):
+    """
+    DETR-DC5 R50 with 6 encoder and 6 decoder layers.
+    The last block of ResNet-50 has dilation to increase
+    output resolution.
+    Achieves 44.6 PQ on COCO val5k.
+    threshold is the minimum confidence required for keeping segments in the prediction
+    """
+    model = _make_detr("resnet50", dilation=True, num_classes=num_classes, mask=True)
+    is_thing_map = {i: i <= 90 for i in range(250)}
+    if pretrained:
+        checkpoint = torch.hub.load_state_dict_from_url(
+            url="https://dl.fbaipublicfiles.com/detr/detr-r50-dc5-panoptic-da08f1b1.pth",
+            map_location="cpu",
+            check_hash=True,
+        )
+        model.load_state_dict(checkpoint["model"])
+    if return_postprocessor:
+        return model, PostProcessPanoptic(is_thing_map, threshold=threshold)
+    return model


+def detr_resnet101_panoptic(
+    pretrained=False, num_classes=250, threshold=0.85, return_postprocessor=False
+):
+    """
+    DETR R101 with 6 encoder and 6 decoder layers.
+    Achieves 45.1 PQ on COCO val5k.
+    threshold is the minimum confidence required for keeping segments in the prediction
+    """
+    model = _make_detr("resnet101", dilation=False, num_classes=num_classes, mask=True)
+    is_thing_map = {i: i <= 90 for i in range(250)}
+    if pretrained:
+        checkpoint = torch.hub.load_state_dict_from_url(
-            url='https://dl.fbaipublicfiles.com/detr/detr-r101-dc5-a2e86def.pth',
-            map_location='cpu',
-            check_hash=True)
-        model.load_state_dict(checkpoint['model'])
+            url="https://dl.fbaipublicfiles.com/detr/detr-r101-panoptic-40021d53.pth",
+            map_location="cpu",
+            check_hash=True,
+        )
+        model.load_state_dict(checkpoint["model"])
+    if return_postprocessor:
+        return model, PostProcessPanoptic(is_thing_map, threshold=threshold)
     return model
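For illustration, a usage sketch of the new panoptic entry points through torch.hub (mirroring how the detection entry points above are consumed; the input tensor here is random, purely to show the call sequence):

    import torch

    # Load a panoptic model together with its postprocessor.
    model, postprocessor = torch.hub.load(
        "facebookresearch/detr",
        "detr_resnet50_panoptic",
        pretrained=True,
        return_postprocessor=True,
        num_classes=250,
    )
    model.eval()

    img = torch.randn(1, 3, 800, 1066)  # dummy input, illustration only
    with torch.no_grad():
        out = model(img)

    # The postprocessor takes the sizes the images had when fed to the model;
    # without an explicit target_sizes, predictions stay at that size.
    result = postprocessor(out, torch.as_tensor(img.shape[-2:]).unsqueeze(0))[0]
    print(len(result["segments_info"]), "segments")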
4 changes: 2 additions & 2 deletions models/detr.py
@@ -319,7 +319,7 @@ def build(args):
         aux_loss=args.aux_loss,
     )
     if args.masks:
-        model = DETRsegm(model)
+        model = DETRsegm(model, freeze_detr=(args.frozen_weights is not None))
     matcher = build_matcher(args)
     weight_dict = {'loss_ce': 1, 'loss_bbox': args.bbox_loss_coef}
     weight_dict['loss_giou'] = args.giou_loss_coef
@@ -344,6 +344,6 @@ def build(args):
         postprocessors['segm'] = PostProcessSegm()
     if args.dataset_file == "coco_panoptic":
         is_thing_map = {i: i <= 90 for i in range(201)}
-        postprocessors["panoptic"] = PostProcessPanoptic(is_thing_map, True, threshold=0.85)
+        postprocessors["panoptic"] = PostProcessPanoptic(is_thing_map, threshold=0.85)

     return model, criterion, postprocessors
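As a reading aid, the is_thing_map comprehension above unpacked: COCO detection category ids (up to 90) are countable "thing" classes, while the higher ids used by the panoptic annotations are uncountable "stuff" classes. PostProcessPanoptic uses this map to decide which segments of the same class may be merged:

    # Sketch of the mapping built above (201 ids for the coco_panoptic setting).
    is_thing_map = {i: i <= 90 for i in range(201)}
    assert is_thing_map[17] is True    # id 17 ("cat") is a thing class
    assert is_thing_map[118] is False  # ids above 90 are stuff classes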
55 changes: 30 additions & 25 deletions models/segmentation.py
@@ -223,17 +223,15 @@ def forward(self, results, outputs, orig_target_sizes, max_target_sizes):
         assert len(orig_target_sizes) == len(max_target_sizes)
         max_h, max_w = max_target_sizes.max(0)[0].tolist()
         outputs_masks = outputs["pred_masks"].squeeze(2)
-        outputs_masks = F.interpolate(
-            outputs_masks, size=(max_h, max_w), mode="bilinear", align_corners=False
-        )
+        outputs_masks = F.interpolate(outputs_masks, size=(max_h, max_w), mode="bilinear", align_corners=False)
         outputs_masks = (outputs_masks.sigmoid() > self.threshold).cpu()

         for i, (cur_mask, t, tt) in enumerate(zip(outputs_masks, max_target_sizes, orig_target_sizes)):
             img_h, img_w = t[0], t[1]
             results[i]["masks"] = cur_mask[:, :img_h, :img_w].unsqueeze(1)
-            results[i]["masks"] = F.interpolate(results[i]["masks"].float(),
-                                                size=tuple(tt.tolist()),
-                                                mode="nearest").byte()
+            results[i]["masks"] = F.interpolate(
+                results[i]["masks"].float(), size=tuple(tt.tolist()), mode="nearest"
+            ).byte()

         return results

@@ -242,33 +240,49 @@ class PostProcessPanoptic(nn.Module):
"""This class converts the output of the model to the final panoptic result, in the format expected by the
coco panoptic API """

def __init__(self, is_thing_map, rescale_to_orig_size=False, threshold=0.85):
def __init__(self, is_thing_map, threshold=0.85):
"""
Parameters:
is_thing_map: This is a whose keys are the class ids, and the values a boolean indicating whether
the class is a thing (True) or a stuff (False) class
rescale_to_orig_size: If true, we use rescale the prediction to the size of the original image.
Otherwise, we keep the size after data augmentation
threshold: confidence threshold: segments with confidence lower than this will be deleted
"""
super().__init__()
self.rescale_to_orig_size = rescale_to_orig_size
self.threshold = threshold
self.is_thing_map = is_thing_map

def forward(self, outputs, targets):
def forward(self, outputs, processed_sizes, target_sizes=None):
""" This function computes the panoptic prediction from the model's predictions.
Parameters:
outputs: This is a dict coming directly from the model. See the model doc for the content.
processed_sizes: This is a list of tuples (or torch tensors) of sizes of the images that were passed to the
model, ie the size after data augmentation but before batching.
target_sizes: This is a list of tuples (or torch tensors) corresponding to the requested final size
of each prediction. If left to None, it will default to the processed_sizes
"""
if target_sizes is None:
target_sizes = processed_sizes
assert len(processed_sizes) == len(target_sizes)
out_logits, raw_masks, raw_boxes = outputs["pred_logits"], outputs["pred_masks"], outputs["pred_boxes"]
assert len(out_logits) == len(raw_masks) == len(targets)
assert len(out_logits) == len(raw_masks) == len(target_sizes)
preds = []
for cur_logits, cur_masks, cur_boxes, target in zip(out_logits, raw_masks, raw_boxes, targets):

def to_tuple(tup):
if isinstance(tup, tuple):
return tup
return tuple(tup.cpu().tolist())

for cur_logits, cur_masks, cur_boxes, size, target_size in zip(
out_logits, raw_masks, raw_boxes, processed_sizes, target_sizes
):
# we filter empty queries and detection below threshold
scores, labels = cur_logits.softmax(-1).max(-1)
keep = labels.ne(outputs["pred_logits"].shape[-1] - 1) & (scores > self.threshold)
cur_scores, cur_classes = cur_logits.softmax(-1).max(-1)
cur_scores = cur_scores[keep]
cur_classes = cur_classes[keep]
cur_masks = cur_masks[keep]
cur_masks = interpolate(cur_masks[None], tuple(target["size"].tolist()), mode="bilinear").squeeze(0)
cur_masks = interpolate(cur_masks[None], to_tuple(size), mode="bilinear").squeeze(0)
cur_boxes = box_ops.box_cxcywh_to_xyxy(cur_boxes[keep])

h, w = cur_masks.shape[-2:]
@@ -301,8 +315,7 @@ def get_ids_area(masks, scores, dedup=False):
                 for eq_id in equiv:
                     m_id.masked_fill_(m_id.eq(eq_id), equiv[0])

-            field = "orig_size" if self.rescale_to_orig_size else "size"
-            final_h, final_w = target[field].cpu().unbind(0)
+            final_h, final_w = to_tuple(target_size)

             seg_img = Image.fromarray(id2rgb(m_id.view(h, w).cpu().numpy()))
             seg_img = seg_img.resize(size=(final_w, final_h), resample=Image.NEAREST)
@@ -335,22 +348,14 @@ def get_ids_area(masks, scores, dedup=False):
             else:
                 cur_classes = torch.ones(1, dtype=torch.long, device=cur_classes.device)

-            image_id = target["image_id"].item()
-
             segments_info = []
             for i, a in enumerate(area):
                 cat = cur_classes[i].item()
                 segments_info.append({"id": i, "isthing": self.is_thing_map[cat], "category_id": cat, "area": a})
             del cur_classes

-            file_name = f"{image_id:012d}.png"
             with io.BytesIO() as out:
                 seg_img.save(out, format="PNG")
-                predictions = {
-                    "image_id": image_id,
-                    "file_name": file_name,
-                    "png_string": out.getvalue(),
-                    "segments_info": segments_info,
-                }
+                predictions = {"png_string": out.getvalue(), "segments_info": segments_info}
             preds.append(predictions)
         return preds
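To make the new calling convention concrete, a self-contained sketch (shapes and sizes are assumptions, not values from this diff; requires the module's optional panopticapi dependency; with random logits virtually no query clears the 0.85 threshold, so expect empty segments_info, which the module's empty-detection branch handles):

    import torch

    from models.segmentation import PostProcessPanoptic

    # Fake outputs with the layout forward() expects: batch of 2, 100 queries,
    # 250 classes plus one "no object" logit, coarse mask resolution.
    outputs = {
        "pred_logits": torch.randn(2, 100, 251),
        "pred_masks": torch.randn(2, 100, 200, 267),
        "pred_boxes": torch.rand(2, 100, 4),
    }
    postprocessor = PostProcessPanoptic({i: i <= 90 for i in range(250)}, threshold=0.85)

    processed_sizes = [(800, 1066), (800, 1066)]  # sizes fed to the model
    orig_sizes = [(480, 640), (512, 683)]         # requested final sizes (optional)
    preds = postprocessor(outputs, processed_sizes, orig_sizes)
    for p in preds:
        print(len(p["png_string"]), "png bytes,", len(p["segments_info"]), "segments")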
