Skip to content

Commit a4f35cb

Browse files
author
Songki Choi
authored
Merge pull request #1099 from openvinotoolkit/sangdaen/hpo_anomaly_progress
[HPO] Enabling ote_anomalib report score for hpopt
2 parents eed9902 + 073d045 commit a4f35cb

File tree

3 files changed

+51
-3
lines changed

3 files changed

+51
-3
lines changed

external/anomaly/ote_anomalib/callbacks/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -18,5 +18,6 @@
1818

1919
from .inference import AnomalyInferenceCallback
2020
from .progress import ProgressCallback
21+
from .score_report import ScoreReportingCallback
2122

22-
__all__ = ["AnomalyInferenceCallback", "ProgressCallback"]
23+
__all__ = ["AnomalyInferenceCallback", "ProgressCallback", "ScoreReportingCallback"]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
"""Score reporting callback"""
2+
3+
# Copyright (C) 2020 Intel Corporation
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions
15+
# and limitations under the License.
16+
17+
from typing import Optional
18+
19+
from ote_sdk.entities.train_parameters import TrainParameters
20+
from pytorch_lightning import Callback
21+
22+
23+
class ScoreReportingCallback(Callback):
24+
"""
25+
Callback for reporting score.
26+
"""
27+
28+
def __init__(self, parameters: Optional[TrainParameters] = None) -> None:
29+
if parameters is not None:
30+
self.score_reporting_callback = parameters.update_progress
31+
else:
32+
self.score_reporting_callback = None
33+
34+
def on_validation_epoch_end(self, trainer, pl_module):
35+
"""
36+
If score exists in trainer.logged_metrics, report the score.
37+
"""
38+
if self.score_reporting_callback is not None:
39+
score = None
40+
metric = getattr(self.score_reporting_callback, 'metric', None)
41+
if metric in trainer.logged_metrics:
42+
score = float(trainer.logged_metrics[metric])
43+
self.score_reporting_callback(progress=0, score=score)

external/anomaly/ote_anomalib/train_task.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
from anomalib.utils.callbacks import MinMaxNormalizationCallback
2020
from ote_anomalib import AnomalyInferenceTask
21-
from ote_anomalib.callbacks import ProgressCallback
21+
from ote_anomalib.callbacks import ProgressCallback, ScoreReportingCallback
2222
from ote_anomalib.data import OTEAnomalyDataModule
2323
from ote_anomalib.logging import get_logger
2424
from ote_sdk.entities.datasets import DatasetEntity
@@ -61,7 +61,11 @@ def train(
6161
logger.info("Training Configs '%s'", config)
6262

6363
datamodule = OTEAnomalyDataModule(config=config, dataset=dataset, task_type=self.task_type)
64-
callbacks = [ProgressCallback(parameters=train_parameters), MinMaxNormalizationCallback()]
64+
callbacks = [
65+
ProgressCallback(parameters=train_parameters),
66+
MinMaxNormalizationCallback(),
67+
ScoreReportingCallback(parameters=train_parameters)
68+
]
6569

6670
self.trainer = Trainer(**config.trainer, logger=False, callbacks=callbacks)
6771
self.trainer.fit(model=self.model, datamodule=datamodule)

0 commit comments

Comments
 (0)