Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

43 panoptic quality metrics #89

Merged
merged 24 commits into from
Oct 8, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
6bf95cd
Merge branch 'coco-panoptic-dataset' into 43-panoptic-quality-metrics
Sep 22, 2021
24da3df
implement new labels
Sep 23, 2021
92d9cc4
Merge branch '42-development-panoptic-module' into 43-panoptic-qualit…
Sep 24, 2021
58fbc7c
update view with multiples sets
Sep 24, 2021
f0cefaa
first version pqmetric and callback in train procedure
Sep 27, 2021
4ac6443
new pq and base callback
Sep 27, 2021
a69b092
Merge branch '42-development-panoptic-module' into 43-panoptic-qualit…
Sep 28, 2021
c095e85
Merge branch '42-development-panoptic-module' into 43-panoptic-qualit…
Sep 28, 2021
acecead
Update print result for PQmetrics
Sep 28, 2021
c0e5b5e
fix bug in mask get_view
Sep 30, 2021
59c247a
Merge branch '42-development-panoptic-module' into 43-panoptic-qualit…
Oct 1, 2021
6cf9a50
Merge branch '42-development-panoptic-module' into 43-panoptic-qualit…
Oct 1, 2021
c6dc6c5
Merge branch '42-development-panoptic-module' into 43-panoptic-qualit…
Oct 1, 2021
2d1c6b2
pq metrics for ds without isthing labels
Oct 4, 2021
e2394f3
Merge branch 'master' into 43-panoptic-quality-metrics
Oct 4, 2021
d08dd74
Add new detrR50PanopticFinetune
Oct 5, 2021
43f4015
fix bad pq metric and mask view by object/cat
Oct 5, 2021
fb71c38
fix PQ metric
Oct 5, 2021
af13e9a
Merge pull request #101 from Visual-Behavior/fix_compute_pq_metric_ma…
Johansmm Oct 6, 2021
6c42bd9
Merge branch '43-panoptic-quality-metrics' of github.com:Visual-Behav…
Oct 6, 2021
95732b2
Merge branch 'master' into 43-panoptic-quality-metrics
Oct 6, 2021
f58d5af
load weights for fix_len = None
Oct 6, 2021
8398dac
change baseMetricCallback name
Oct 6, 2021
ef3ea3e
fix classes reindex
Oct 6, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 56 additions & 13 deletions alodataset/coco_panoptic_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import numpy as np
import torch
from PIL import Image
from typing import Union

from alodataset.utils.panoptic_utils import rgb2id
from alodataset.utils.panoptic_utils import masks_to_boxes
Expand Down Expand Up @@ -70,7 +71,7 @@ def __init__(
split=Split.TRAIN,
return_masks: bool = True,
classes: list = None,
fix_classes_len: int = 250, # Match with pre-trained weights
fix_classes_len: int = None, # Match with pre-trained weights
**kwargs,
):
super(CocoPanopticDataset, self).__init__(name=name, split=split, **kwargs)
Expand All @@ -82,7 +83,7 @@ def __init__(
self.ann_folder = os.path.join(self.dataset_dir, self.get_split_ann_folder())
self.ann_file = os.path.join(self.dataset_dir, self.get_split_ann_file())
self.return_masks = return_masks
self.label_names = None
self.label_names, self.label_types, self.label_types_names = None, None, None
self.items = self._get_sequences()

# Fix classes if it is desired
Expand Down Expand Up @@ -110,6 +111,15 @@ def __init__(
items.append(self.items[i])
self.items = items

# Fix label_types: If `classes` is desired, remove types that not include this classes and fix indices
if self.label_types is not None:
for ltype, vtype in self.label_types.items():
vtype = [x for b, x in enumerate(vtype) if self._ids_renamed[b] != -1]
ltn = list(sorted(set([self.label_types_names[ltype][vt] for vt in vtype])))
index = {b: ltn.index(p) for b, p in enumerate(self.label_types_names[ltype]) if p in ltn}
self.label_types[ltype] = [index[idx] for idx in vtype]
self.label_types_names[ltype] = ltn

# Fix number of label names if desired
if fix_classes_len is not None:
if fix_classes_len > len(self.label_names):
Expand Down Expand Up @@ -146,9 +156,26 @@ def _get_sequences(self):
if "categories" in coco:
nb_category = max(cat["id"] for cat in coco["categories"])
self.label_names = ["N/A"] * (nb_category + 1)

# Get types names
self.label_types_names = {
k: list(sorted(set([cat[k] for cat in coco["categories"]]))) + ["N/A"]
for k in coco["categories"][0].keys()
if k not in ["id", "name"]
}

# Make index between type category id and label id
self.label_types = {
k: [len(self.label_types_names[k]) - 1] * (nb_category + 1) for k in self.label_types_names
}
if "isthing" in self.label_types_names:
self.label_types_names["isthing"] = ["stuff", "thing", "N/A"]
for cat in coco["categories"]:
self.label_names[cat["id"]] = cat["name"]
print("Done")
for k in self.label_types:
self.label_types[k][cat["id"]] = (
cat[k] if k == "isthing" else self.label_types_names[k].index(cat[k])
)
return items

def get_split_ann_folder(self):
Expand All @@ -173,6 +200,18 @@ def get_split_ann_file(self):
assert self.split in self.SPLIT_ANN_FILES
return self.SPLIT_ANN_FILES[self.split]

def _append_type_labels(self, element: Union[BoundingBoxes2D, Mask], labels):
if self.label_types is not None:
for ktype in self.label_types:
label_types = torch.as_tensor(self.label_types[ktype])[labels]
label_types = Labels(
label_types.to(torch.float32),
labels_names=self.label_types_names[ktype],
names=("N"),
encoding="id",
)
element.append_labels(label_types, name=ktype)

def getitem(self, idx):
"""Get the :mod:`Frame <aloscene.frame>` corresponds to *idx* index

Expand Down Expand Up @@ -213,28 +252,32 @@ def getitem(self, idx):

# Make aloscene.frame
frame = Frame(img_path)

labels_2d = Labels(labels.to(torch.float32), labels_names=self.label_names, names=("N"), encoding="id")
boxes_2d = BoundingBoxes2D(
masks_to_boxes(masks),
boxes_format="xyxy",
absolute=True,
frame_size=frame.HW,
names=("N", None),
labels=labels_2d,
masks_to_boxes(masks), boxes_format="xyxy", absolute=True, frame_size=frame.HW, names=("N", None),
)
boxes_2d.append_labels(labels_2d, name="category")
self._append_type_labels(boxes_2d, labels)
frame.append_boxes2d(boxes_2d)

if self.return_masks:
masks_2d = Mask(masks, names=("N", "H", "W"), labels=labels_2d)
masks_2d = Mask(masks, names=("N", "H", "W"))
masks_2d.append_labels(labels_2d, name="category")
self._append_type_labels(masks_2d, labels)
frame.append_segmentation(masks_2d)
return frame


if __name__ == "__main__":
coco_seg = CocoPanopticDataset(sample=True)
coco_seg = CocoPanopticDataset(sample=False)
for f, frames in enumerate(coco_seg.train_loader(batch_size=2)):
frames = Frame.batch_list(frames)
frames.get_view().render()
labels_set = "category" if isinstance(frames.boxes2d[0].labels, dict) else None
views = [fr.boxes2d.get_view(fr, labels_set=labels_set) for fr in frames]
if hasattr(frames, "segmentation"):
views += [fr.segmentation.get_view(fr, labels_set=labels_set) for fr in frames]
frames.get_view(views).render()
# frames.get_view(labels_set=labels_set).render()

if f > 1:
break
16 changes: 11 additions & 5 deletions alodataset/utils/panoptic_utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
from typing import Union
import numpy as np
import torch

# from alonet.metrics.compute_pq import VOID

VOID_CLASS_ID = -1
GLOBAL_COLOR_SET = np.random.uniform(0, 1, (300, 3))
GLOBAL_COLOR_SET[VOID_CLASS_ID] = [0, 0, 0]
OFFSET = 256 * 256 * 256


# Function get from PanopticAPI: https://github.com/cocodataset/panopticapi/blob/master/panopticapi/utils.py
def rgb2id(color):
def rgb2id(color: Union[list, np.ndarray]):
if isinstance(color, np.ndarray) and len(color.shape) == 3:
if color.dtype == np.uint8:
color = color.astype(np.int32)
Expand All @@ -14,20 +20,20 @@ def rgb2id(color):


# Function get from PanopticAPI: https://github.com/cocodataset/panopticapi/blob/master/panopticapi/utils.py
def id2rgb(id_map, random_color=True):
def id2rgb(id_map: np.ndarray, random_color: bool = True):
if random_color:
return (256 * GLOBAL_COLOR_SET[id_map]).astype(np.uint8)
return GLOBAL_COLOR_SET[id_map]
if isinstance(id_map, np.ndarray):
id_map_copy = id_map.copy()
rgb_shape = tuple(list(id_map.shape) + [3])
rgb_map = np.zeros(rgb_shape, dtype=np.uint8)
for i in range(3):
rgb_map[..., i] = id_map_copy % 256
id_map_copy //= 256
return rgb_map
return rgb_map / 255.0
color = []
for _ in range(3):
color.append(id_map % 256)
color.append((id_map % 256) / 255.0)
id_map //= 256
return color

Expand Down
2 changes: 2 additions & 0 deletions alonet/callbacks/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from .object_detector_callback import ObjectDetectorCallback
from .metrics_callback import MetricsCallback
from .base_metrics_callback import InstancesBaseMetricsCallback
from .map_metrics_callback import ApMetricsCallback
from .pq_metrics_callback import PQMetricsCallback
154 changes: 154 additions & 0 deletions alonet/callbacks/base_metrics_callback.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
import pytorch_lightning as pl
import aloscene
from alonet import metrics
from pytorch_lightning.utilities import rank_zero_only

# import wandb


class InstancesBaseMetricsCallback(pl.Callback):
def __init__(self, base_metric: metrics, *args, **kwargs):
self.metrics = []
self.base_metric = base_metric
super().__init__(*args, **kwargs)

def inference(self, pl_module: pl.LightningModule, m_outputs: dict, **kwargs):
"""This method will call the `infernece` method of the module's model and will expect to receive the
predicted boxes2D and/or Masks.

Parameters
----------
pl_module : pl.LightningModule
Pytorch lighting module with inference function
m_outputs : dict
Forward outputs

Returns
-------
:mod:`~aloscene.bounding_boxes_2d`, :mod:`~aloscene.Mask`
Boxes and masks predicted from inference function

Notes
-----
If `m_outputs` does not contain "pred_masks" attribute, a [None]*B attribute will be returned by default
"""
b_pred_masks = None
if "pred_masks" in m_outputs:
b_pred_boxes, b_pred_masks = pl_module.inference(m_outputs, **kwargs)
else:
b_pred_boxes = pl_module.inference(m_outputs, **kwargs)
if not isinstance(m_outputs, list):
b_pred_boxes = [b_pred_boxes]
b_pred_masks = [b_pred_masks]
elif b_pred_masks is None:
b_pred_masks = [None] * len(b_pred_boxes)
return b_pred_boxes, b_pred_masks

@rank_zero_only
def on_validation_batch_end(
self,
trainer: pl.Trainer,
pl_module: pl.LightningModule,
outputs: dict,
batch: list,
batch_idx: int,
dataloader_idx: int,
):
"""Method call after each validation batch. This class is a pytorch lightning callback, therefore
this method will by automaticly call by pl.

This method will call the `infernece` method of the module's model and will expect to receive the
predicted boxes2D and/or Masks. Theses elements will be aggregate to compute the different metrics in the
`on_validation_end` method.
The infernece method will be call using the `m_outputs` key from the outputs dict. If `m_outputs` is a list,
then the list will be consider as an temporal list. Therefore, this callback will aggregate the prediction
for each element of the sequence and will log the final results with the timestep prefix val/t/ instead of
simply /val/

Parameters
----------
trainer: pl.Trainer
Pytorch lightning trainer
pl_module: pl.LightningModule
Pytorch lightning module. The "m_outputs" key is expected for this this callback to work properly.
outputs:
Training/Validation step outputs of the pl.LightningModule class.
batch: list
Batch comming from the dataloader. Usually, a list of frame.
batch_idx: int
Id the batch
dataloader_idx: int
Dataloader batch ID.
"""
if isinstance(batch, list): # Resize frames for mask procedure
batch = batch[0].batch_list(batch)

b_pred_boxes, b_pred_masks = self.inference(pl_module, outputs["m_outputs"])
is_temporal = isinstance(outputs["m_outputs"], list)
for b, (t_pred_boxes, t_pred_masks) in enumerate(zip(b_pred_boxes, b_pred_masks)):

# Retrieve the matching GT boxes at the same time step
t_gt_boxes = batch[b].boxes2d
t_gt_masks = batch[b].segmentation

if not is_temporal:
t_gt_boxes = [t_gt_boxes]
t_gt_masks = [t_gt_masks]

if t_pred_masks is None:
t_pred_masks = [None] * len(t_gt_masks)

# Add the samples to metrics for each batch of the current sequence
for t, (gt_boxes, pred_boxes, gt_masks, pred_masks) in enumerate(
zip(t_gt_boxes, t_pred_boxes, t_gt_masks, t_pred_masks)
):
if t + 1 > len(self.metrics):
self.metrics.append(self.base_metric())
self.add_sample(self.metrics[t], pred_boxes, gt_boxes, pred_masks, gt_masks)

@rank_zero_only
def add_sample(
self,
base_metric: metrics,
pred_boxes: aloscene.BoundingBoxes2D,
gt_boxes: aloscene.BoundingBoxes2D,
pred_masks: aloscene.Mask = None,
gt_masks: aloscene.Mask = None,
):
"""Add a smaple to some `alonet.metrics` class. One might want to inhert this method
to edit the `pred_boxes` and `gt_boxes` boxes before to add them to the ApMetrics class.

Parameters
----------
ap_metrics: Union[:mod:`~alonet.metrics.ApMetrics`, :mod:`~alonet.metrics.PQMetrics`
ApMetrics intance.
pred_boxes: :mod:`~aloscene.BoundingBoxes2D`
Predicted boxes2D.
gt_boxes: :mod:`~aloscene.BoundingBoxes2D`
GT boxes2d.
pred_masks: :mod:`~aloscene.Mask`
Predicted Masks for segmentation task
gt_masks: :mod:`~aloscene.Mask`
GT masks in segmentation task.
"""
base_metric.add_sample(p_bbox=pred_boxes, t_bbox=gt_boxes, p_mask=pred_masks, t_mask=gt_masks)

@rank_zero_only
def on_validation_end(self, trainer, pl_module):
"""Method call at the end of each validation epoch. The method will use all the aggregate
data over the epoch to log the final metrics on wandb.
This class is a pytorch lightning callback, therefore this method will by automaticly call by pl.

This method is currently a WIP since some metrics are not logged due to some wandb error when loading
Table.

Parameters
----------
trainer: pl.Trainer
Pytorch lightning trainer
pl_module: pl.LightningModule
Pytorch lightning module
"""
if trainer.logger is None:
return
raise Exception("To inhert in a child class")
Loading