Add Semi-SL Instance Segmentation #2444

Merged Sep 6, 2023 (46 commits)

Commits
aa6d593
added semisl MT. Loss not working.
kprokofi Jun 20, 2023
8e62c9a
added recipe. Unbiased teacher works
kprokofi Jun 22, 2023
fff61f7
added MT
kprokofi Jun 22, 2023
fd12d4f
exps contin
kprokofi Jun 26, 2023
b29cf99
proceed with experiments
kprokofi Jun 27, 2023
b06c0a7
fix errors in forward
kprokofi Jun 29, 2023
219fbf0
change hyperparams. Add clip for testing
kprokofi Jun 26, 2023
50793ba
some exps
kprokofi Jun 28, 2023
591aede
change hyperparams
kprokofi Jun 28, 2023
33ca7dd
added per class thrsh
kprokofi Jul 5, 2023
aae3a95
minor:
kprokofi Jul 5, 2023
34e726c
exps
kprokofi Jul 4, 2023
5179d4f
add switching parameter for thrsh
kprokofi Jul 4, 2023
8d5d467
din thrsh
kprokofi Jul 6, 2023
5b15d7b
added DEMA
kprokofi Jul 19, 2023
de2e103
added dinam thrsh
kprokofi Jul 26, 2023
31fa97a
removed dinam
kprokofi Aug 2, 2023
46b525f
final round exps
kprokofi Aug 7, 2023
a4c1fa1
added MT and semi-sl for ResNet
kprokofi Aug 21, 2023
4d4f34f
added semisl stage. Remove old otx
kprokofi Aug 21, 2023
673fe99
training launches. Merged code with OD task.
kprokofi Aug 22, 2023
7bd3b8b
fix pre-commit
kprokofi Aug 22, 2023
8ae9b92
added tests for Semi-SL IS
kprokofi Aug 23, 2023
9fa18e0
fix detection resolution
kprokofi Aug 23, 2023
62a6006
added unit test for MT
kprokofi Aug 23, 2023
7460917
overwrite iter params in semi-sl config. Return configuration.yaml back
kprokofi Aug 23, 2023
911c448
added semisl for effnet. However, it still doesn't work
kprokofi Aug 24, 2023
271aa6b
changed teacher forward method. Fixed pre-commit
kprokofi Aug 29, 2023
96f5d70
fix unit tests
kprokofi Aug 30, 2023
9e64f5d
fixed detection issues. Moved data pipeline
kprokofi Aug 30, 2023
0a3fbee
minor
kprokofi Aug 30, 2023
c82ccc7
fixed det unit test configure
kprokofi Aug 30, 2023
b55f350
rename file
kprokofi Aug 30, 2023
d1a100b
Merge branch 'kp/semisl_instance_seg' of https://github.com/openvinot…
kprokofi Aug 30, 2023
ebf324e
revert detection scaling back
kprokofi Aug 30, 2023
4e63859
rename semisl data
kprokofi Aug 30, 2023
65e48d2
some changes in unit test for focal loss
kprokofi Aug 30, 2023
cbb181e
fixed pre-commit. returned incremental part back
kprokofi Sep 1, 2023
600cb94
rename selfsl in semisl
kprokofi Sep 1, 2023
dcb1e3e
rename MeanTeacherHook
kprokofi Sep 1, 2023
8f50a29
return yolox data_pipeline
kprokofi Sep 1, 2023
1f9f6ea
fix pre-commit
kprokofi Sep 4, 2023
2b2e910
added one more unit test
kprokofi Sep 4, 2023
31760d2
fix pre-commit
kprokofi Sep 4, 2023
15d7e5c
reply comments
kprokofi Sep 5, 2023
dc245d1
Merge branch 'develop' into kp/semisl_instance_seg
kprokofi Sep 5, 2023
src/otx/algorithms/common/adapters/mmcv/hooks/__init__.py (2 additions, 2 deletions)

@@ -37,6 +37,7 @@
 from .ib_loss_hook import IBLossHook
 from .logger_hook import LoggerReplaceHook, OTXLoggerHook
 from .loss_dynamics_tracking_hook import LossDynamicsTrackingHook
+from .mean_teacher_hook import MeanTeacherHook
 from .mem_cache_hook import MemCacheHook
 from .model_ema_v2_hook import ModelEmaV2Hook
 from .no_bias_decay_hook import NoBiasDecayHook
@@ -51,7 +52,6 @@
 from .semisl_cls_hook import SemiSLClsHook
 from .task_adapt_hook import TaskAdaptHook
 from .two_crop_transform_hook import TwoCropTransformHook
-from .unbiased_teacher_hook import UnbiasedTeacherHook

 __all__ = [
     "AdaptiveRepeatDataHook",
@@ -87,7 +87,7 @@
     "SemiSLClsHook",
     "TaskAdaptHook",
     "TwoCropTransformHook",
-    "UnbiasedTeacherHook",
+    "MeanTeacherHook",
     "MemCacheHook",
     "LossDynamicsTrackingHook",
 ]
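
Because the hook is registered with `@HOOKS.register_module()` and re-exported here under its new name, a semi-SL recipe attaches it purely by that registry string. Below is a hypothetical recipe fragment, not a file from this PR: `momentum`, `epoch_momentum`, and `interval` mirror the constructor shown in the EMA-hook diff further down, while `start_epoch` and all numeric values are illustrative assumptions.

```python
# Hypothetical recipe fragment (not the repository's actual config file).
custom_hooks = [
    dict(
        type="MeanTeacherHook",   # was "UnbiasedTeacherHook" before this PR
        momentum=0.0004,          # illustrative EMA momentum
        epoch_momentum=0.0,
        start_epoch=2,            # assumed to be forwarded via **kwargs to the parent EMA hook
    ),
]
```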

@@ -25,6 +25,26 @@ def __init__(self, momentum=0.0002, epoch_momentum=0.0, interval=1, **kwargs):
         self.epoch_momentum = epoch_momentum
         self.interval = interval

+    def before_run(self, runner):
+        """To resume the model with its EMA parameters more easily.
+
+        Register EMA parameters as ``named_buffer`` on the model.
+        """
+        if is_module_wrapper(runner.model):
+            model = runner.model.module.model_s if hasattr(runner.model.module, "model_s") else runner.model.module
+        else:
+            model = runner.model.model_s if hasattr(runner.model, "model_s") else runner.model
+        self.param_ema_buffer = {}
+        self.model_parameters = dict(model.named_parameters(recurse=True))
+        for name, value in self.model_parameters.items():
+            # "." is not allowed in module's buffer name
+            buffer_name = f"ema_{name.replace('.', '_')}"
+            self.param_ema_buffer[name] = buffer_name
+            model.register_buffer(buffer_name, value.data.clone())
+        self.model_buffers = dict(model.named_buffers(recurse=True))
+        if self.checkpoint is not None:
+            runner.resume(self.checkpoint)
+
     def before_train_epoch(self, runner):
         """Update the momentum."""
         if self.epoch_momentum > 0.0:
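
The new `before_run` above registers an `ema_*` copy of every parameter as a named buffer. A minimal, standalone PyTorch sketch of why that matters (toy module, not the hook's real code): buffers are part of `state_dict`, so the EMA weights are written to checkpoints and restored on resume together with the ordinary parameters.

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 2)  # toy stand-in for the student model

# register one EMA buffer per parameter ("." is not allowed in buffer names)
for name, param in model.named_parameters():
    model.register_buffer(f"ema_{name.replace('.', '_')}", param.data.clone())

print(sorted(model.state_dict().keys()))
# ['bias', 'ema_bias', 'ema_weight', 'weight'] -- the EMA copies travel with the checkpoint
```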

@@ -76,7 +76,7 @@ def before_run(self, runner):

     def before_train_epoch(self, runner):
         """Momentum update."""
-        if runner.epoch == self.start_epoch:
+        if runner.epoch + 1 == self.start_epoch:
             self._copy_model()
             self.enabled = True

@@ -110,21 +110,24 @@ def _get_model(self, runner):
     def _copy_model(self):
         with torch.no_grad():
             for name, src_param in self.src_params.items():
-                dst_param = self.dst_params[name]
-                dst_param.data.copy_(src_param.data)
+                if not name.startswith("ema_"):
+                    dst_param = self.dst_params[name]
+                    dst_param.data.copy_(src_param.data)

     def _ema_model(self):
         momentum = min(self.momentum, 1.0)
         with torch.no_grad():
             for name, src_param in self.src_params.items():
-                dst_param = self.dst_params[name]
-                dst_param.data.copy_(dst_param.data * (1 - momentum) + src_param.data * momentum)
+                if not name.startswith("ema_"):
+                    dst_param = self.dst_params[name]
+                    dst_param.data.copy_(dst_param.data * (1 - momentum) + src_param.data * momentum)

     def _diff_model(self):
         diff_sum = 0.0
         with torch.no_grad():
             for name, src_param in self.src_params.items():
-                dst_param = self.dst_params[name]
-                diff = ((src_param - dst_param) ** 2).sum()
-                diff_sum += diff
+                if not name.startswith("ema_"):
+                    dst_param = self.dst_params[name]
+                    diff = ((src_param - dst_param) ** 2).sum()
+                    diff_sum += diff
        return diff_sum
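
The `ema_` prefix check added above keeps the bookkeeping buffers registered in `before_run` out of the teacher update. Below is a standalone sketch of the update rule itself, with toy dictionaries of tensors instead of the hook's `src_params`/`dst_params`; `dst` plays the role of the EMA (teacher) weights and `src` the live (student) weights, which is the usual mean-teacher convention, and the momentum value is the default shown further up.

```python
import torch

momentum = 0.0002  # default from the constructor shown above
src_params = {"weight": torch.ones(3), "ema_weight": torch.zeros(3)}   # live (student) side
dst_params = {"weight": torch.zeros(3), "ema_weight": torch.zeros(3)}  # EMA (teacher) side

with torch.no_grad():
    for name, src in src_params.items():
        if not name.startswith("ema_"):        # skip the bookkeeping buffers
            dst = dst_params[name]
            dst.copy_(dst * (1 - momentum) + src * momentum)

# dst_params["weight"] has moved slightly towards src_params["weight"] (each entry is 0.0002);
# dst_params["ema_weight"] is left untouched.
```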

@@ -14,29 +14,27 @@


 @HOOKS.register_module()
-class UnbiasedTeacherHook(DualModelEMAHook):
-    """UnbiasedTeacherHook for semi-supervised learnings."""
+class MeanTeacherHook(DualModelEMAHook):
+    """MeanTeacherHook for semi-supervised learnings."""

-    def __init__(self, min_pseudo_label_ratio=0.1, **kwargs):
+    def __init__(self, **kwargs):
         super().__init__(**kwargs)
-        self.min_pseudo_label_ratio = min_pseudo_label_ratio
         self.unlabeled_loss_enabled = False

     def before_train_epoch(self, runner):
         """Enable unlabeled loss if over start epoch."""
+        super().before_train_epoch(runner)
+
         if runner.epoch + 1 < self.start_epoch:
             return
         if self.unlabeled_loss_enabled:
             return

-        super().before_train_epoch(runner)
-
         average_pseudo_label_ratio = self._get_average_pseudo_label_ratio(runner)
         logger.info(f"avr_ps_ratio: {average_pseudo_label_ratio}")
-        if average_pseudo_label_ratio > self.min_pseudo_label_ratio:
-            self._get_model(runner).enable_unlabeled_loss()
-            self.unlabeled_loss_enabled = True
-            logger.info("---------- Enabled unlabeled loss")
+        self._get_model(runner).enable_unlabeled_loss(True)
+        self.unlabeled_loss_enabled = True
+        logger.info("---------- Enabled unlabeled loss and EMA smoothing")

     def after_train_iter(self, runner):
         """Update ema parameter every self.interval iterations."""
@@ -46,7 +44,6 @@ def after_train_iter(self, runner):

         if runner.epoch + 1 < self.start_epoch or self.unlabeled_loss_enabled is False:
             # Just copy parameters before enabled
-            self._copy_model()
             return

         # EMA
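
Taken together, the changes above mean the renamed hook no longer gates the unlabeled loss on a minimum pseudo-label ratio: the ratio is still computed and logged, but once `start_epoch` is reached the unlabeled loss and EMA smoothing are switched on unconditionally, and the per-iteration copy before that point is dropped. A toy, self-contained trace of that schedule (all names are hypothetical; the `epoch + 1` arithmetic mirrors the checks in the diff):

```python
def simulate_schedule(num_epochs=4, iters_per_epoch=3, start_epoch=2, interval=1):
    """Trace the events of a simplified MeanTeacherHook-style schedule."""
    unlabeled_loss_enabled = False
    events = []
    for epoch in range(num_epochs):
        # before_train_epoch: one-time copy plus enabling of the unlabeled loss
        if epoch + 1 == start_epoch:
            events.append(f"epoch {epoch}: copy student -> teacher (parent EMA hook)")
        if epoch + 1 >= start_epoch and not unlabeled_loss_enabled:
            unlabeled_loss_enabled = True
            events.append(f"epoch {epoch}: unlabeled loss + EMA smoothing enabled")
        # after_train_iter: EMA update once the unlabeled loss is active
        for it in range(iters_per_epoch):
            if unlabeled_loss_enabled and it % interval == 0:
                events.append(f"epoch {epoch} iter {it}: teacher = EMA(teacher, student)")
    return events

for event in simulate_schedule():
    print(event)
```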

@@ -14,8 +14,8 @@
 from .custom_vfnet_detector import CustomVFNet
 from .custom_yolox_detector import CustomYOLOX
 from .l2sp_detector_mixin import L2SPDetectorMixin
+from .mean_teacher import MeanTeacher
 from .sam_detector_mixin import SAMDetectorMixin
-from .unbiased_teacher import UnbiasedTeacher

 __all__ = [
     "CustomATSS",
@@ -29,6 +29,6 @@
     "CustomYOLOX",
     "L2SPDetectorMixin",
     "SAMDetectorMixin",
-    "UnbiasedTeacher",
     "CustomMaskRCNNTileOptimized",
+    "MeanTeacher",
 ]
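
The `MeanTeacher` detector exported here wraps a student/teacher pair for semi-supervised training. Below is a generic, self-contained sketch of the mean-teacher step such a wrapper is built around; this is the textbook recipe, not the repository's exact implementation: the per-class thresholds mentioned in the commit log are reduced to a single confidence threshold, and the "detector" is a toy linear layer.

```python
import torch
import torch.nn as nn

student = nn.Linear(8, 4)                      # stand-in for the student detector
teacher = nn.Linear(8, 4)                      # EMA copy, never updated by the optimizer
teacher.load_state_dict(student.state_dict())
for p in teacher.parameters():
    p.requires_grad_(False)

def train_step(labeled_x, labeled_y, unlabeled_weak, unlabeled_strong,
               unlabeled_loss_enabled=True, unsup_weight=1.0, conf_thr=0.5):
    # supervised loss on the labeled batch
    sup_loss = nn.functional.mse_loss(student(labeled_x), labeled_y)

    unsup_loss = torch.tensor(0.0)
    if unlabeled_loss_enabled:
        with torch.no_grad():
            # the teacher pseudo-labels the weakly augmented view; low-confidence
            # predictions are masked out (a single threshold here, per-class in the PR)
            pseudo = teacher(unlabeled_weak)
            mask = (pseudo.sigmoid().max(dim=1).values > conf_thr).float()
        pred = student(unlabeled_strong)
        per_sample = nn.functional.mse_loss(pred, pseudo, reduction="none").mean(dim=1)
        unsup_loss = (per_sample * mask).mean()

    return sup_loss + unsup_weight * unsup_loss

loss = train_step(torch.randn(2, 8), torch.randn(2, 4),
                  torch.randn(2, 8), torch.randn(2, 8))
loss.backward()  # gradients flow only into the student; the teacher is updated by EMA elsewhere
```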