|
| 1 | +"""Score reporting callback""" |
| 2 | + |
| 3 | +# Copyright (C) 2020 Intel Corporation |
| 4 | +# |
| 5 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 6 | +# you may not use this file except in compliance with the License. |
| 7 | +# You may obtain a copy of the License at |
| 8 | +# |
| 9 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 10 | +# |
| 11 | +# Unless required by applicable law or agreed to in writing, |
| 12 | +# software distributed under the License is distributed on an "AS IS" BASIS, |
| 13 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 14 | +# See the License for the specific language governing permissions |
| 15 | +# and limitations under the License. |
| 16 | + |
| 17 | +from typing import Optional |
| 18 | + |
| 19 | +from ote_sdk.entities.train_parameters import TrainParameters |
| 20 | +from pytorch_lightning import Callback |
| 21 | + |
| 22 | + |
| 23 | +class ScoreReportingCallback(Callback): |
| 24 | + """ |
| 25 | + Callback for reporting score. |
| 26 | + """ |
| 27 | + |
| 28 | + def __init__(self, parameters: Optional[TrainParameters] = None) -> None: |
| 29 | + if parameters is not None: |
| 30 | + self.score_reporting_callback = parameters.update_progress |
| 31 | + else: |
| 32 | + self.score_reporting_callback = None |
| 33 | + |
| 34 | + def on_validation_epoch_end(self, trainer, pl_module): |
| 35 | + """ |
| 36 | + If score exists in trainer.logged_metrics, report the score. |
| 37 | + """ |
| 38 | + if self.score_reporting_callback is not None: |
| 39 | + score = None |
| 40 | + metric = getattr(self.score_reporting_callback, 'metric', None) |
| 41 | + if metric in trainer.logged_metrics: |
| 42 | + score = float(trainer.logged_metrics[metric]) |
| 43 | + self.score_reporting_callback(progress=0, score=score) |
0 commit comments