Add Semi-SL Instance Segmentation (#2444)
* added semisl MT. Loss not working.

* added recipe. Unbiased Teacher works

* added MT

* experiments continued

* proceed with experiments

* fix errors in forward

* change hyperparams. Add clip for testing

* some experiments

* change hyperparams

* added per-class threshold

* minor

* experiments

* add switching parameter for threshold

* dynamic threshold

* added DEMA

* added dynamic threshold

* removed dynamic threshold

* final round of experiments

* added MT and semi-sl for ResNet

* added semisl stage. Remove old otx

* training launches. Merged code with OD task.

* fix pre-commit

* added tests for Semi-SL IS

* fix detection resolution

* added unit test for MT

* overwrite iter params in semi-sl config. Return configuration.yaml back

* added semisl for effnet. However, it still doesn't work

* changed teacher forward method. Fixed pre-commit

* fix unit tests

* fixed detection issues. Moved data pipeline

* minor

* fixed det unit test configure

* rename file

* revert detection scaling back

* rename semisl data

* some changes in unit test for focal loss

* fixed pre-commit. returned incremental part back

* rename selfsl to semisl

* rename MeanTeacherHook

* return yolox data_pipeline

* fix pre-commit

* added one more unit test

* fix pre-commit

* addressed review comments
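The recipe these messages converge on is the Mean Teacher scheme: a teacher model pseudo-labels the unlabeled images, the student trains on labeled data plus confidence-filtered pseudo-labels, and the teacher weights track the student through an exponential moving average. A minimal, illustrative sketch of that loop (the classifier, threshold, and variable names below are placeholders, not code from this commit):

# Illustrative Mean Teacher step; a tiny classifier stands in for a detector.
import copy

import torch
from torch import nn

student = nn.Linear(4, 2)             # stand-in for the student model
teacher = copy.deepcopy(student)      # teacher starts as a frozen copy of the student
for p in teacher.parameters():
    p.requires_grad_(False)

def ema_update(teacher, student, momentum=0.0004):
    """teacher <- (1 - momentum) * teacher + momentum * student, in place."""
    with torch.no_grad():
        for t, s in zip(teacher.parameters(), student.parameters()):
            t.mul_(1.0 - momentum).add_(s, alpha=momentum)

unlabeled = torch.randn(8, 4)         # dummy "unlabeled" batch
with torch.no_grad():
    probs = teacher(unlabeled).softmax(dim=-1)
keep = probs.max(dim=-1).values > 0.7         # confidence filter (per-class in this PR)
pseudo_targets = probs.argmax(dim=-1)[keep]   # teacher predictions become pseudo-labels

if keep.any():
    loss = nn.functional.cross_entropy(student(unlabeled[keep]), pseudo_targets)
    loss.backward()                   # in practice, added to the supervised loss
ema_update(teacher, student)          # teacher slowly follows the student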
kprokofi authored Sep 6, 2023
1 parent b78d9e5 commit e28026b
Showing 49 changed files with 1,355 additions and 318 deletions.
4 changes: 2 additions & 2 deletions src/otx/algorithms/common/adapters/mmcv/hooks/__init__.py
@@ -37,6 +37,7 @@
 from .ib_loss_hook import IBLossHook
 from .logger_hook import LoggerReplaceHook, OTXLoggerHook
 from .loss_dynamics_tracking_hook import LossDynamicsTrackingHook
+from .mean_teacher_hook import MeanTeacherHook
 from .mem_cache_hook import MemCacheHook
 from .model_ema_v2_hook import ModelEmaV2Hook
 from .no_bias_decay_hook import NoBiasDecayHook
@@ -51,7 +52,6 @@
 from .semisl_cls_hook import SemiSLClsHook
 from .task_adapt_hook import TaskAdaptHook
 from .two_crop_transform_hook import TwoCropTransformHook
-from .unbiased_teacher_hook import UnbiasedTeacherHook
 
 __all__ = [
     "AdaptiveRepeatDataHook",
@@ -87,7 +87,7 @@
     "SemiSLClsHook",
     "TaskAdaptHook",
     "TwoCropTransformHook",
-    "UnbiasedTeacherHook",
+    "MeanTeacherHook",
     "MemCacheHook",
     "LossDynamicsTrackingHook",
 ]
@@ -25,6 +25,26 @@ def __init__(self, momentum=0.0002, epoch_momentum=0.0, interval=1, **kwargs):
         self.epoch_momentum = epoch_momentum
         self.interval = interval
 
+    def before_run(self, runner):
+        """To resume model with its ema parameters more friendly.
+
+        Register ema parameter as ``named_buffer`` to model.
+        """
+        if is_module_wrapper(runner.model):
+            model = runner.model.module.model_s if hasattr(runner.model.module, "model_s") else runner.model.module
+        else:
+            model = runner.model.model_s if hasattr(runner.model, "model_s") else runner.model
+        self.param_ema_buffer = {}
+        self.model_parameters = dict(model.named_parameters(recurse=True))
+        for name, value in self.model_parameters.items():
+            # "." is not allowed in module's buffer name
+            buffer_name = f"ema_{name.replace('.', '_')}"
+            self.param_ema_buffer[name] = buffer_name
+            model.register_buffer(buffer_name, value.data.clone())
+        self.model_buffers = dict(model.named_buffers(recurse=True))
+        if self.checkpoint is not None:
+            runner.resume(self.checkpoint)
+
     def before_train_epoch(self, runner):
         """Update the momentum."""
         if self.epoch_momentum > 0.0:
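The before_run addition above registers the EMA weights as module buffers; buffers are how PyTorch persists non-trainable tensors in state_dict, so the smoothed teacher weights end up in checkpoints and survive resume. A tiny sketch of that behaviour (module and buffer name here are illustrative, not from this commit):

from torch import nn

m = nn.Linear(3, 1)
# "." is not allowed in buffer names, hence the flattened ema_* naming above
m.register_buffer("ema_weight", m.weight.data.clone())
print(sorted(m.state_dict().keys()))   # ['bias', 'ema_weight', 'weight']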
@@ -76,7 +76,7 @@ def before_run(self, runner):
 
     def before_train_epoch(self, runner):
         """Momentum update."""
-        if runner.epoch == self.start_epoch:
+        if runner.epoch + 1 == self.start_epoch:
             self._copy_model()
             self.enabled = True
 
@@ -110,21 +110,24 @@ def _get_model(self, runner):
     def _copy_model(self):
         with torch.no_grad():
             for name, src_param in self.src_params.items():
-                dst_param = self.dst_params[name]
-                dst_param.data.copy_(src_param.data)
+                if not name.startswith("ema_"):
+                    dst_param = self.dst_params[name]
+                    dst_param.data.copy_(src_param.data)
 
     def _ema_model(self):
         momentum = min(self.momentum, 1.0)
         with torch.no_grad():
             for name, src_param in self.src_params.items():
-                dst_param = self.dst_params[name]
-                dst_param.data.copy_(dst_param.data * (1 - momentum) + src_param.data * momentum)
+                if not name.startswith("ema_"):
+                    dst_param = self.dst_params[name]
+                    dst_param.data.copy_(dst_param.data * (1 - momentum) + src_param.data * momentum)
 
     def _diff_model(self):
         diff_sum = 0.0
         with torch.no_grad():
             for name, src_param in self.src_params.items():
-                dst_param = self.dst_params[name]
-                diff = ((src_param - dst_param) ** 2).sum()
-                diff_sum += diff
+                if not name.startswith("ema_"):
+                    dst_param = self.dst_params[name]
+                    diff = ((src_param - dst_param) ** 2).sum()
+                    diff_sum += diff
         return diff_sum
@@ -14,29 +14,27 @@
 
 
 @HOOKS.register_module()
-class UnbiasedTeacherHook(DualModelEMAHook):
-    """UnbiasedTeacherHook for semi-supervised learnings."""
+class MeanTeacherHook(DualModelEMAHook):
+    """MeanTeacherHook for semi-supervised learnings."""
 
-    def __init__(self, min_pseudo_label_ratio=0.1, **kwargs):
+    def __init__(self, **kwargs):
         super().__init__(**kwargs)
-        self.min_pseudo_label_ratio = min_pseudo_label_ratio
         self.unlabeled_loss_enabled = False
 
     def before_train_epoch(self, runner):
         """Enable unlabeled loss if over start epoch."""
-        super().before_train_epoch(runner)
-
         if runner.epoch + 1 < self.start_epoch:
             return
         if self.unlabeled_loss_enabled:
             return
 
+        super().before_train_epoch(runner)
 
-        average_pseudo_label_ratio = self._get_average_pseudo_label_ratio(runner)
-        logger.info(f"avr_ps_ratio: {average_pseudo_label_ratio}")
-        if average_pseudo_label_ratio > self.min_pseudo_label_ratio:
-            self._get_model(runner).enable_unlabeled_loss()
-            self.unlabeled_loss_enabled = True
-            logger.info("---------- Enabled unlabeled loss")
+        self._get_model(runner).enable_unlabeled_loss(True)
+        self.unlabeled_loss_enabled = True
+        logger.info("---------- Enabled unlabeled loss and EMA smoothing")
 
     def after_train_iter(self, runner):
         """Update ema parameter every self.interval iterations."""
@@ -46,7 +44,6 @@ def after_train_iter(self, runner):
 
         if runner.epoch + 1 < self.start_epoch or self.unlabeled_loss_enabled is False:
             # Just copy parameters before enabled
-            self._copy_model()
             return
 
         # EMA
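For reference, a hook registered with @HOOKS.register_module() like the one above is normally enabled through the custom_hooks list of an mmcv-style config. The argument names and values below are placeholders inferred from attributes visible in the diff, not the defaults shipped in this commit:

custom_hooks = [
    dict(
        type="MeanTeacherHook",  # resolved through the HOOKS registry
        momentum=0.0004,         # EMA momentum (placeholder value)
        start_epoch=2,           # epoch at which the unlabeled loss kicks in (placeholder)
    )
]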
@@ -14,8 +14,8 @@
 from .custom_vfnet_detector import CustomVFNet
 from .custom_yolox_detector import CustomYOLOX
 from .l2sp_detector_mixin import L2SPDetectorMixin
+from .mean_teacher import MeanTeacher
 from .sam_detector_mixin import SAMDetectorMixin
-from .unbiased_teacher import UnbiasedTeacher
 
 __all__ = [
     "CustomATSS",
@@ -29,6 +29,6 @@
     "CustomYOLOX",
     "L2SPDetectorMixin",
     "SAMDetectorMixin",
-    "UnbiasedTeacher",
     "CustomMaskRCNNTileOptimized",
+    "MeanTeacher",
 ]