Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Semi-SL Instance Segmentation #2444

Merged
merged 46 commits into from
Sep 6, 2023
Merged
Show file tree
Hide file tree
Changes from 32 commits
Commits
Show all changes
46 commits
Select commit Hold shift + click to select a range
aa6d593
added semisl MT. Loss not working.
kprokofi Jun 20, 2023
8e62c9a
added recipie. Unbiased teacher works
kprokofi Jun 22, 2023
fff61f7
added MT
kprokofi Jun 22, 2023
fd12d4f
exps contin
kprokofi Jun 26, 2023
b29cf99
proceed with experiments
kprokofi Jun 27, 2023
b06c0a7
fix errors in forward
kprokofi Jun 29, 2023
219fbf0
change hyperparams. Add clip for testing
kprokofi Jun 26, 2023
50793ba
some exps
kprokofi Jun 28, 2023
591aede
change hyperparams
kprokofi Jun 28, 2023
33ca7dd
added per class thrsh
kprokofi Jul 5, 2023
aae3a95
minor:
kprokofi Jul 5, 2023
34e726c
exps
kprokofi Jul 4, 2023
5179d4f
add switching parameter for thrsh
kprokofi Jul 4, 2023
8d5d467
din thrsh
kprokofi Jul 6, 2023
5b15d7b
added DEMA
kprokofi Jul 19, 2023
de2e103
added dinam thrsh
kprokofi Jul 26, 2023
31fa97a
removed dinam
kprokofi Aug 2, 2023
46b525f
final round exps
kprokofi Aug 7, 2023
a4c1fa1
added MT and semi-sl for ResNet
kprokofi Aug 21, 2023
4d4f34f
added semisl stage. Remove old otx
kprokofi Aug 21, 2023
673fe99
training launches. Merged code with OD task.
kprokofi Aug 22, 2023
7bd3b8b
fix pre-commit
kprokofi Aug 22, 2023
8ae9b92
added tests for Semi-SL IS
kprokofi Aug 23, 2023
9fa18e0
fix detection resolution
kprokofi Aug 23, 2023
62a6006
added unit test for MT
kprokofi Aug 23, 2023
7460917
overwrite iter params in semi-sl config. Return configuration.ymal back
kprokofi Aug 23, 2023
911c448
added semisl for effnet. Hovewer it still doesn't work
kprokofi Aug 24, 2023
271aa6b
changed teacher forward method. Fixed pre-commit
kprokofi Aug 29, 2023
96f5d70
fix unit tests
kprokofi Aug 30, 2023
9e64f5d
fixed detection issues. Moved data pipeline
kprokofi Aug 30, 2023
0a3fbee
minor
kprokofi Aug 30, 2023
c82ccc7
fixed det unit test configure
kprokofi Aug 30, 2023
b55f350
rename file
kprokofi Aug 30, 2023
d1a100b
Merge branch 'kp/semisl_instance_seg' of https://github.com/openvinot…
kprokofi Aug 30, 2023
ebf324e
revert detection scaling back
kprokofi Aug 30, 2023
4e63859
rename semisl data
kprokofi Aug 30, 2023
65e48d2
some changes in unit test for focal loss
kprokofi Aug 30, 2023
cbb181e
fixed pre-commit. returned incremental part back
kprokofi Sep 1, 2023
600cb94
rename selfsl in semisl
kprokofi Sep 1, 2023
dcb1e3e
rename MeanTeacherHook
kprokofi Sep 1, 2023
8f50a29
return yolox data_pipeline
kprokofi Sep 1, 2023
1f9f6ea
fix pre-commit
kprokofi Sep 4, 2023
2b2e910
added one more unit test
kprokofi Sep 4, 2023
31760d2
fix pre-commit
kprokofi Sep 4, 2023
15d7e5c
reply comments
kprokofi Sep 5, 2023
dc245d1
Merge branch 'develop' into kp/semisl_instance_seg
kprokofi Sep 5, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,26 @@ def __init__(self, momentum=0.0002, epoch_momentum=0.0, interval=1, **kwargs):
self.epoch_momentum = epoch_momentum
self.interval = interval

def before_run(self, runner):
"""To resume model with it's ema parameters more friendly.

Register ema parameter as ``named_buffer`` to model
"""
if is_module_wrapper(runner.model):
model = runner.model.module.model_s if hasattr(runner.model.module, "model_s") else runner.model.module
else:
model = runner.model.model_s if hasattr(runner.model, "model_s") else runner.model
self.param_ema_buffer = {}
self.model_parameters = dict(model.named_parameters(recurse=True))
for name, value in self.model_parameters.items():
# "." is not allowed in module's buffer name
buffer_name = f"ema_{name.replace('.', '_')}"
self.param_ema_buffer[name] = buffer_name
model.register_buffer(buffer_name, value.data.clone())
self.model_buffers = dict(model.named_buffers(recurse=True))
if self.checkpoint is not None:
runner.resume(self.checkpoint)

def before_train_epoch(self, runner):
"""Update the momentum."""
if self.epoch_momentum > 0.0:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def before_run(self, runner):

def before_train_epoch(self, runner):
"""Momentum update."""
if runner.epoch == self.start_epoch:
if runner.epoch + 1 == self.start_epoch:
self._copy_model()
self.enabled = True

Expand Down Expand Up @@ -110,21 +110,24 @@ def _get_model(self, runner):
def _copy_model(self):
with torch.no_grad():
for name, src_param in self.src_params.items():
dst_param = self.dst_params[name]
dst_param.data.copy_(src_param.data)
if not name.startswith("ema_"):
dst_param = self.dst_params[name]
dst_param.data.copy_(src_param.data)

def _ema_model(self):
momentum = min(self.momentum, 1.0)
with torch.no_grad():
for name, src_param in self.src_params.items():
dst_param = self.dst_params[name]
dst_param.data.copy_(dst_param.data * (1 - momentum) + src_param.data * momentum)
if not name.startswith("ema_"):
dst_param = self.dst_params[name]
dst_param.data.copy_(dst_param.data * (1 - momentum) + src_param.data * momentum)

def _diff_model(self):
diff_sum = 0.0
with torch.no_grad():
for name, src_param in self.src_params.items():
dst_param = self.dst_params[name]
diff = ((src_param - dst_param) ** 2).sum()
diff_sum += diff
if not name.startswith("ema_"):
dst_param = self.dst_params[name]
diff = ((src_param - dst_param) ** 2).sum()
diff_sum += diff
return diff_sum
Original file line number Diff line number Diff line change
Expand Up @@ -17,26 +17,24 @@
class UnbiasedTeacherHook(DualModelEMAHook):
kprokofi marked this conversation as resolved.
Show resolved Hide resolved
"""UnbiasedTeacherHook for semi-supervised learnings."""

def __init__(self, min_pseudo_label_ratio=0.1, **kwargs):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.min_pseudo_label_ratio = min_pseudo_label_ratio
self.unlabeled_loss_enabled = False

def before_train_epoch(self, runner):
"""Enable unlabeled loss if over start epoch."""
super().before_train_epoch(runner)

if runner.epoch + 1 < self.start_epoch:
return
if self.unlabeled_loss_enabled:
return

super().before_train_epoch(runner)

average_pseudo_label_ratio = self._get_average_pseudo_label_ratio(runner)
logger.info(f"avr_ps_ratio: {average_pseudo_label_ratio}")
if average_pseudo_label_ratio > self.min_pseudo_label_ratio:
self._get_model(runner).enable_unlabeled_loss()
self.unlabeled_loss_enabled = True
logger.info("---------- Enabled unlabeled loss")
self._get_model(runner).enable_unlabeled_loss(True)
self.unlabeled_loss_enabled = True
logger.info("---------- Enabled unlabeled loss and EMA smoothing")

def after_train_iter(self, runner):
"""Update ema parameter every self.interval iterations."""
Expand All @@ -46,7 +44,6 @@ def after_train_iter(self, runner):

if runner.epoch + 1 < self.start_epoch or self.unlabeled_loss_enabled is False:
# Just copy parameters before enabled
self._copy_model()
return

# EMA
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
from .custom_vfnet_detector import CustomVFNet
from .custom_yolox_detector import CustomYOLOX
from .l2sp_detector_mixin import L2SPDetectorMixin
from .mean_teacher import MeanTeacher
from .sam_detector_mixin import SAMDetectorMixin
from .unbiased_teacher import UnbiasedTeacher

__all__ = [
"CustomATSS",
Expand All @@ -27,6 +27,6 @@
"CustomYOLOX",
"L2SPDetectorMixin",
"SAMDetectorMixin",
"UnbiasedTeacher",
"CustomMaskRCNNTileOptimized",
"MeanTeacher",
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,267 @@
"""UnbiasedTeacher Class for mmdetection detectors."""
# Copyright (C) 2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#

import copy
import functools

import numpy as np
import torch
from mmdet.core import bbox2result, bbox2roi
from mmdet.core.mask.structures import BitmapMasks
from mmdet.models import DETECTORS, build_detector
from mmdet.models.detectors import BaseDetector

from otx.algorithms.common.utils.logger import get_logger

from .sam_detector_mixin import SAMDetectorMixin

logger = get_logger()

# TODO: Need to fix pylint issues
# pylint: disable=abstract-method, too-many-locals, unused-argument


@DETECTORS.register_module()
class MeanTeacher(SAMDetectorMixin, BaseDetector):
"""Mean teacher framework for detection and instance segmentation."""
jaegukhyun marked this conversation as resolved.
Show resolved Hide resolved

def __init__(
self,
arch_type,
unlabeled_loss_weights={"cls": 1.0, "bbox": 1.0, "mask": 1.0},
pseudo_conf_thresh=0.7,
bg_loss_weight=-1.0,
min_pseudo_label_ratio=0.0,
**kwargs
):
super().__init__()
self.unlabeled_loss_weights = unlabeled_loss_weights
self.pseudo_conf_thresh = pseudo_conf_thresh
self.bg_loss_weight = bg_loss_weight
self.min_pseudo_label_ratio = min_pseudo_label_ratio
cfg = kwargs.copy()
cfg["type"] = arch_type
self.model_s = build_detector(cfg)
self.model_t = copy.deepcopy(self.model_s)
# warmup for first epochs
self.enable_unlabeled_loss(False)

# Hooks for super_type transparent weight load/save
self._register_state_dict_hook(self.state_dict_hook)
self._register_load_state_dict_pre_hook(functools.partial(self.load_state_dict_pre_hook, self))

def extract_feat(self, imgs):
"""Extract features for UnbiasedTeacher."""
return self.model_s.extract_feat(imgs)

def simple_test(self, img, img_metas, **kwargs):
"""Test from img with UnbiasedTeacher."""
return self.model_s.simple_test(img, img_metas, **kwargs)

def aug_test(self, imgs, img_metas, **kwargs):
"""Aug Test from img with UnbiasedTeacher."""
return self.model_s.aug_test(imgs, img_metas, **kwargs)

def forward_dummy(self, img, **kwargs):
"""Dummy forward function for UnbiasedTeacher."""
return self.model_s.forward_dummy(img, **kwargs)

def enable_unlabeled_loss(self, mode=True):
"""Enable function for UnbiasedTeacher unlabeled loss."""
self.unlabeled_loss_enabled = mode

def forward_teacher(self, img, img_metas):
"""Method to extract predictions (pseudo labeles) from teacher."""
x = self.model_t.extract_feat(img)
proposal_list = self.model_t.rpn_head.simple_test_rpn(x, img_metas)

det_bboxes, det_labels = self.model_t.roi_head.simple_test_bboxes(
x, img_metas, proposal_list, self.model_t.test_cfg.rcnn, rescale=False
)

bbox_results = [
bbox2result(det_bboxes[i], det_labels[i], self.model_t.roi_head.bbox_head.num_classes)
for i in range(len(det_bboxes))
]

if not self.model_t.with_mask:
return bbox_results
else:
ori_shapes = tuple(meta["ori_shape"] for meta in img_metas)
scale_factors = tuple(meta["scale_factor"] for meta in img_metas)

num_imgs = len(det_bboxes)
if all(det_bbox.shape[0] == 0 for det_bbox in det_bboxes):
segm_results = [
[[] for _ in range(self.model_t.roi_head.mask_head.num_classes)] for _ in range(num_imgs)
]
else:
_bboxes = [det_bboxes[i][:, :4] for i in range(len(det_bboxes))]
mask_rois = bbox2roi(_bboxes)
mask_results = self.model_t.roi_head._mask_forward(x, mask_rois)
mask_pred = mask_results["mask_pred"]
# split batch mask prediction back to each image
num_mask_roi_per_img = [len(det_bbox) for det_bbox in det_bboxes]
mask_preds = mask_pred.split(num_mask_roi_per_img, 0)

# apply mask post-processing to each image individually
segm_results = []
for i in range(num_imgs):
if det_bboxes[i].shape[0] == 0:
segm_results.append([[] for _ in range(self.model_t.roi_head.mask_head.num_classes)])
else:
segm_result = self.model_t.roi_head.mask_head.get_scaled_seg_masks(
mask_preds[i],
_bboxes[i],
det_labels[i],
self.model_t.test_cfg.rcnn,
ori_shapes[i],
scale_factors[i],
rescale=False,
)
segm_results.append(segm_result)

return list(zip(bbox_results, segm_results))

def forward_train(self, img, img_metas, gt_bboxes, gt_labels, gt_masks=None, gt_bboxes_ignore=None, **kwargs):
"""Forward function for UnbiasedTeacher."""
losses = {}
# Supervised loss
# TODO: check img0 only option (which is common for mean teacher method)
forward_train = functools.partial(
self.model_s.forward_train,
img,
img_metas,
gt_bboxes,
gt_labels,
gt_bboxes_ignore=(gt_bboxes_ignore if gt_bboxes_ignore else None),
)
if self.model_s.with_mask:
sl_losses = forward_train(gt_masks=gt_masks)
else:
sl_losses = forward_train()
losses.update(sl_losses)

if not self.unlabeled_loss_enabled:
return losses

# Pseudo labels from teacher
ul_args = kwargs.get("extra_0", {})
ul_img = ul_args.get("img")
ul_img0 = ul_args.get("img0")
ul_img_metas = ul_args.get("img_metas")
if ul_img is None:
return losses
with torch.no_grad():
if self.model_t.with_mask:
teacher_outputs = self.forward_teacher(ul_img0, ul_img_metas)
else:
teacher_outputs = self.model_t.forward_test([ul_img0], [ul_img_metas], rescale=False)
current_device = ul_img0[0].device
pseudo_bboxes, pseudo_labels, pseudo_masks, pseudo_ratio = self.generate_pseudo_labels(
teacher_outputs, device=current_device, img_meta=ul_img_metas, **kwargs
)
losses.update(ps_ratio=torch.tensor([pseudo_ratio], device=current_device))

# Unsupervised loss
# Compute only if min_pseudo_label_ratio is reached
if pseudo_ratio >= self.min_pseudo_label_ratio:
if self.bg_loss_weight >= 0.0:
self.model_s.bbox_head.bg_loss_weight = self.bg_loss_weight
if self.model_t.with_mask:
ul_losses = self.model_s.forward_train(
ul_img, ul_img_metas, pseudo_bboxes, pseudo_labels, gt_masks=pseudo_masks
)
else:
ul_losses = self.model_s.forward_train(ul_img, ul_img_metas, pseudo_bboxes, pseudo_labels)

if self.bg_loss_weight >= 0.0:
self.model_s.bbox_head.bg_loss_weight = -1.0

for ul_loss_name in ul_losses.keys():
if ul_loss_name.startswith("loss_"):
ul_loss = ul_losses[ul_loss_name]
target_loss = ul_loss_name.split("_")[-1]
if self.unlabeled_loss_weights[target_loss] == 0:
continue
self._update_unlabeled_loss(losses, ul_loss, ul_loss_name, self.unlabeled_loss_weights[target_loss])
return losses

def generate_pseudo_labels(self, teacher_outputs, img_meta, **kwargs):
"""Generate pseudo label for UnbiasedTeacher."""
device = kwargs.pop("device")
all_pseudo_bboxes = []
all_pseudo_labels = []
all_pseudo_masks = []
num_all_bboxes = 0
num_all_pseudo = 0
for i, teacher_bboxes_labels in enumerate(teacher_outputs):
image_shape = img_meta[i]["img_shape"][:-1]
pseudo_bboxes = []
pseudo_labels = []
pseudo_masks = []
if self.model_t.with_mask:
teacher_bboxes_labels = zip(*teacher_bboxes_labels)
for label, teacher_bboxes_masks in enumerate(teacher_bboxes_labels):
if self.model_t.with_mask:
teacher_bboxes = teacher_bboxes_masks[0]
teacher_masks = teacher_bboxes_masks[1]
else:
teacher_bboxes = teacher_bboxes_masks
confidences = teacher_bboxes[:, -1]
pseudo_indices = confidences > self.pseudo_conf_thresh
pseudo_bboxes.append(teacher_bboxes[pseudo_indices, :4]) # model output: [x y w h conf]
pseudo_labels.append(np.full([sum(pseudo_indices)], label))
if self.model_t.with_mask:
if np.any(pseudo_indices):
teacher_masks = [np.expand_dims(mask, 0) for mask in teacher_masks]
pseudo_masks.append(np.concatenate(teacher_masks)[pseudo_indices])
else:
pseudo_masks.append(np.array([]).reshape(0, *image_shape))

num_all_bboxes += teacher_bboxes.shape[0]
if len(pseudo_bboxes):
num_all_pseudo += pseudo_bboxes[-1].shape[0]

if len(pseudo_bboxes) > 0:
all_pseudo_bboxes.append(torch.from_numpy(np.concatenate(pseudo_bboxes)).to(device))
all_pseudo_labels.append(torch.from_numpy(np.concatenate(pseudo_labels)).to(device))
if self.model_t.with_mask:
all_pseudo_masks.append(BitmapMasks(np.concatenate(pseudo_masks), *image_shape))

pseudo_ratio = float(num_all_pseudo) / num_all_bboxes if num_all_bboxes > 0 else 0.0
return all_pseudo_bboxes, all_pseudo_labels, all_pseudo_masks, pseudo_ratio

@staticmethod
def _update_unlabeled_loss(sum_loss, loss, loss_name, weight):
if isinstance(loss, list):
sum_loss[loss_name + "_ul"] = [cur_loss * weight for cur_loss in loss]
else:
sum_loss[loss_name + "_ul"] = loss * weight

@staticmethod
def state_dict_hook(module, state_dict, prefix, *args, **kwargs): # pylint: disable=unused-argument
"""Redirect student model as output state_dict (teacher as auxilliary)."""
logger.info("----------------- MeanTeacherSegmentor.state_dict_hook() called")
for key in list(state_dict.keys()):
value = state_dict.pop(key)
if not prefix or key.startswith(prefix):
key = key.replace(prefix, "", 1)
if key.startswith("model_s."):
key = key.replace("model_s.", "", 1)
elif key.startswith("model_t."):
continue
key = prefix + key
state_dict[key] = value
return state_dict

@staticmethod
def load_state_dict_pre_hook(module, state_dict, *args, **kwargs): # pylint: disable=unused-argument
"""Redirect input state_dict to teacher model."""
logger.info("----------------- MeanTeacherSegmentor.load_state_dict_pre_hook() called")
for key in list(state_dict.keys()):
value = state_dict.pop(key)
state_dict["model_s." + key] = value
state_dict["model_t." + key] = value
Loading