"""Module defining Mix-in class of SAMClassifier."""
# Copyright (C) 2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
import datumaro as dm
import numpy as np
import pandas as pd
from otx.algorithms.common.utils.logger import get_logger
from otx.api.entities.dataset_item import DatasetItemEntityWithID
from otx.api.entities.datasets import DatasetEntity
from otx.core.data.noisy_label_detection import LossDynamicsTracker, LossDynamicsTrackingMixin
logger = get_logger()


class SAMClassifierMixin:
    """SAM-enabled BaseClassifier mix-in."""

    def train_step(self, data, optimizer=None, **kwargs):
        """Save the current batch data so the SAM gradient can be computed on it."""
        self.current_batch = data
        return super().train_step(data, optimizer, **kwargs)
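
# A minimal composition sketch (illustrative, not part of the OTX API): the
# mix-in is meant to precede an mmcls-style BaseClassifier in the MRO, and
# ``self.current_batch`` is presumably read back by a SAM optimizer hook to
# run the second forward/backward pass on the perturbed weights.
#
#   class MySAMClassifier(SAMClassifierMixin, BaseClassifier):  # hypothetical
#       ...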


class MultiClassClsLossDynamicsTracker(LossDynamicsTracker):
    """Loss dynamics tracker for the multi-class classification task."""

    def __init__(self) -> None:
        super().__init__()

    def init_with_otx_dataset(self, otx_dataset: DatasetEntity[DatasetItemEntityWithID]) -> None:
        """Initialize the tracker with the DatasetEntity it will annotate and export."""
        otx_labels = otx_dataset.get_labels()
        label_categories = dm.LabelCategories.from_iterable([label_entity.name for label_entity in otx_labels])
        self.otx_label_map = {label_entity.id_: idx for idx, label_entity in enumerate(otx_labels)}

        def _convert_anns(item: DatasetItemEntityWithID):
            return [
                dm.Label(label=self.otx_label_map[label.id_])
                for ann in item.get_annotations()
                for label in ann.get_labels()
            ]

        self._export_dataset = dm.Dataset.from_iterable(
            [
                dm.DatasetItem(
                    id=item.id_,
                    subset="train",
                    media=dm.Image.from_file(path=item.media.path, size=(item.media.height, item.media.width))
                    if item.media.path
                    else dm.Image.from_numpy(
                        data=getattr(item.media, "_Image__data"), size=(item.media.height, item.media.width)
                    ),
                    annotations=_convert_anns(item),
                )
                for item in otx_dataset
            ],
            infos={"purpose": "noisy_label_detection", "task": "OTX-MultiClassCls"},
            categories={dm.AnnotationType.label: label_categories},
        )

        super().init_with_otx_dataset(otx_dataset)
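
    # A hedged sketch of what one exported Datumaro item carries after
    # ``export()`` below has attached the tracked statistics (the attribute
    # names follow the DataFrame columns; the values are illustrative):
    #
    #   dm.DatasetItem(
    #       id="<entity_id>",
    #       subset="train",
    #       annotations=[dm.Label(label=0, attributes={
    #           "iters": np.array([0, 1, 2]),
    #           "loss_dynamics": np.array([2.31, 2.18, 2.05]),
    #       })],
    #   )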

    def accumulate(self, outputs, iter) -> None:
        """Accumulate training loss dynamics for each training step."""
        entity_ids = outputs["entity_ids"]
        label_ids = np.squeeze(outputs["label_ids"])
        loss_dyns = outputs["loss_dyns"]

        for entity_id, label_id, loss_dyn in zip(entity_ids, label_ids, loss_dyns):
            self._loss_dynamics[(entity_id, label_id)].append((iter, loss_dyn))
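
    # ``self._loss_dynamics`` is presumably a defaultdict(list) provided by the
    # LossDynamicsTracker base class, mapping (entity_id, label_id) to a list of
    # (iteration, loss) pairs, as the append() usage above suggests.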

    def export(self, output_path: str) -> None:
        """Export loss dynamics statistics to Datumaro format."""
        df = pd.DataFrame.from_dict(
            {
                k: (np.array([iter for iter, _ in arr]), np.array([value for _, value in arr]))
                for k, arr in self._loss_dynamics.items()
            },
            orient="index",
            columns=["iters", "loss_dynamics"],
        )

        for (entity_id, label_id), row in df.iterrows():
            item = self._export_dataset.get(entity_id, "train")
            for ann in item.annotations:
                if isinstance(ann, dm.Label) and ann.label == self.otx_label_map[label_id]:
                    ann.attributes = row.to_dict()

        self._export_dataset.export(output_path, format="datumaro")
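
# A usage sketch of the tracker lifecycle; only the tracker methods are from
# this module, the training-loop variables are hypothetical:
#
#   tracker = MultiClassClsLossDynamicsTracker()
#   tracker.init_with_otx_dataset(otx_dataset)
#   for step in range(num_steps):                 # hypothetical loop
#       outputs = model.train_step(batch)         # must carry entity/label ids
#       tracker.accumulate(outputs, step)
#   tracker.export("./loss_dynamics")             # writes Datumaro format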


class ClsLossDynamicsTrackingMixin(LossDynamicsTrackingMixin):
    """Mix-in to track loss dynamics during training for classification tasks."""

    def __init__(self, track_loss_dynamics: bool = False, **kwargs):
        if track_loss_dynamics:
            if getattr(self, "multilabel", False) or getattr(self, "hierarchical", False):
                raise RuntimeError("Multi-label and hierarchical classification tasks are not supported yet.")

            # Per-sample losses are required to track each sample's dynamics,
            # so force the head's loss reduction to "none".
            head_cfg = kwargs.get("head", None)
            loss_cfg = head_cfg.get("loss", None)
            loss_cfg["reduction"] = "none"

        # This should be called after modifying the "reduction" config.
        super().__init__(**kwargs)

        # This should be called after super().__init__(),
        # since LossDynamicsTrackingMixin.__init__() creates self._loss_dyns_tracker.
        self._loss_dyns_tracker = MultiClassClsLossDynamicsTracker()

    def train_step(self, data, optimizer=None, **kwargs):
        """The iteration step for training.

        If the loss dynamics tracker has not been initialized, this simply follows
        BaseClassifier.train_step(); otherwise, it steps with loss dynamics tracking.
        """
        if self.loss_dyns_tracker.initialized:
            return self._train_step_with_tracking(data, optimizer, **kwargs)
        return super().train_step(data, optimizer, **kwargs)

    def _train_step_with_tracking(self, data, optimizer=None, **kwargs):
        losses = self(**data)

        # With reduction="none", the loss is a per-sample vector, not a scalar.
        loss_dyns = losses["loss"].detach().cpu().numpy()
        assert not np.isscalar(loss_dyns)

        entity_ids = [img_meta["entity_id"] for img_meta in data["img_metas"]]
        label_ids = [img_meta["label_id"] for img_meta in data["img_metas"]]

        loss, log_vars = self._parse_losses(losses)

        outputs = dict(
            loss=loss,
            log_vars=log_vars,
            loss_dyns=loss_dyns,
            entity_ids=entity_ids,
            label_ids=label_ids,
            num_samples=len(data["img"].data),
        )

        return outputs
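
# A minimal sketch of constructing a tracking-enabled classifier, assuming an
# mmcls-style config dict; the head/loss keys mirror what __init__() above
# reads, while the class and loss names are illustrative:
#
#   classifier = MyClassifier(                    # hypothetical subclass using
#       track_loss_dynamics=True,                 # ClsLossDynamicsTrackingMixin
#       head=dict(loss=dict(type="CrossEntropyLoss", reduction="mean")),
#   )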