
Commit 6af3a80

code arrangement
1 parent 6aa1a0c commit 6af3a80

1 file changed: +15 −64


ote_cli/ote_cli/utils/hpo.py (+15 −64)
@@ -24,17 +24,17 @@
 from math import ceil
 from os import path as osp
 from pathlib import Path
-from typing import Callable, Optional
+from typing import Optional

 import torch
 import yaml
+from mmcv.utils import ConfigDict # pylint: disable=import-error
 from ote_sdk.configuration.helper import create
 from ote_sdk.entities.model import ModelEntity
 from ote_sdk.entities.model_template import TaskType
 from ote_sdk.entities.subset import Subset
 from ote_sdk.entities.task_environment import TaskEnvironment
 from ote_sdk.entities.train_parameters import TrainParameters, UpdateProgressCallback
-from mmcv.utils import ConfigDict

 from ote_cli.datasets import get_dataset_class
 from ote_cli.utils.importing import get_impl_class
@@ -280,7 +280,7 @@ def run_hpo_trainer(
 
 # make callback to report score to hpopt every epoch
 train_param = TrainParameters(
-False, HpoCallback(hp_config, hp_config["metric"], task), ModelSavedCallback()
+False, HpoCallback(hp_config, hp_config["metric"], task), None
 )
 
 task.train(dataset=dataset, output_model=output_model, train_parameters=train_param)
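This hunk drops the ModelSavedCallback (which only printed a debug message and is deleted further down in this commit) from the TrainParameters call, passing None in its place so HpoCallback remains the only callback handed to the task. As a rough, hypothetical sketch of a callback compatible with the progress-reporting signature HpoCallback implements (progress plus an optional score), not the project's own class:

    from typing import Optional

    class PrintProgressCallback:
        """Hypothetical stand-in for a progress-reporting callback."""

        def __call__(self, progress: float, score: Optional[float] = None):
            # progress is the training progress value reported by the task;
            # score, when provided, is the latest validation metric
            print(f"progress={progress:.1f}, score={score}")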
@@ -349,19 +349,6 @@ def set_resume_path_to_config(self, resume_path):
 
 def prepare_hpo(self, hp_config):
 """update config of the each task framework for the HPO"""
-# if _is_cls_framework_task(self._task_type):
-# # pylint: disable=attribute-defined-outside-init
-# self._scratch_space = _get_hpo_trial_workdir(hp_config)
-# self._cfg.data.save_dir = self._scratch_space
-# self._cfg.model.save_all_chkpts = True
-# elif _is_det_framework_task(self._task_type) or _is_seg_framework_task(
-# self._task_type
-# ):
-# self._config.work_dir = _get_hpo_trial_workdir(hp_config)
-# self._config.checkpoint_config["max_keep_ckpts"] = (
-# hp_config["iterations"] + 10
-# )
-# self._config.checkpoint_config["interval"] = 1
 if (
 _is_cls_framework_task(self._task_type)
 or _is_det_framework_task(self._task_type)
@@ -373,7 +360,9 @@ def prepare_hpo(self, hp_config):
 )
 )
 self.set_override_configurations(cfg)
-self._output_path = _get_hpo_trial_workdir(hp_config)
+self._output_path = _get_hpo_trial_workdir( # pylint: disable=attribute-defined-outside-init
+hp_config
+)
 
 def prepare_saving_initial_weight(self, save_path):
 """add a hook which saves initial model weight before training"""
@@ -382,11 +371,10 @@ def prepare_saving_initial_weight(self, save_path):
 or _is_det_framework_task(task_type)
 or _is_seg_framework_task(task_type)
 ):
-cfg = { "custom_hooks" : [
-ConfigDict(
-type="SaveInitialWeightHook",
-save_path=save_path
-)]
+cfg = {
+"custom_hooks": [
+ConfigDict(type="SaveInitialWeightHook", save_path=save_path)
+]
 }
 self.set_override_configurations(cfg)
 else:
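The SaveInitialWeightHook referenced here is defined in the training backend rather than in this file; the hunk above only reflows how it is registered through the custom_hooks override. For orientation, a rough sketch of how such an mmcv custom hook is typically written (this re-implementation is assumed, not taken from the repository):

    import torch
    from mmcv.runner import HOOKS, Hook

    @HOOKS.register_module()
    class SaveInitialWeightHook(Hook):  # assumed re-implementation, for illustration
        def __init__(self, save_path):
            self.save_path = save_path

        def before_run(self, runner):
            # dump the untrained weights so every HPO trial can start from
            # the same initial state
            torch.save(runner.model.state_dict(), self.save_path)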
@@ -395,7 +383,6 @@ def prepare_saving_initial_weight(self, save_path):
 "initial weight should be saved before HPO."
 )
 
-
 return HpoTrainTask
 
 
@@ -446,7 +433,8 @@ def _load_hpopt_config(file_path):
 def _get_best_model_weight_path(hpo_dir: str, trial_num: str, task_type: TaskType):
 """Return best model weight from HPO trial directory"""
 best_weight_path = None
-if (_is_cls_framework_task(task_type)
+if (
+_is_cls_framework_task(task_type)
 or _is_det_framework_task(task_type)
 or _is_seg_framework_task(task_type)
 ):
@@ -457,7 +445,7 @@ def _get_best_model_weight_path(hpo_dir: str, trial_num: str, task_type: TaskTyp
 break
 elif _is_anomaly_framework_task(task_type):
 # TODO need to implement later
-best_weight_path = ""
+pass
 
 return best_weight_path
 
@@ -500,14 +488,6 @@ def __call__(self, progress: float, score: Optional[float] = None):
 self.hpo_task.cancel_training()
 
 
-class ModelSavedCallback(Callable):
-def __call__(self, path: str, event: str):
-if event == "initialized":
-print(
-f"********** called model saved callback ({path}, {event}) **************"
-)
-
-
 class HpoManager:
 """Manage overall HPO process"""
 
@@ -555,7 +535,6 @@ def __init__(
 train_dataset_size = len(dataset.get_subset(Subset.TRAINING))
 val_dataset_size = len(dataset.get_subset(Subset.VALIDATION))
 
-
 # make batch size range lower than train set size
 env_hp = self.environment.get_hyper_parameters()
 if (
@@ -627,27 +606,6 @@ def __init__(
 # Prevent each trials from being stopped during warmup stage
 batch_size = default_hyper_parameters.get(batch_size_name)
 if "min_iterations" not in hpopt_cfg and batch_size is not None:
-# if _is_cls_framework_task(task_type):
-# with open(
-# osp.join(
-# osp.dirname(
-# self.environment.model_template.model_template_path
-# ),
-# "model.yaml",
-# ),
-# "r",
-# encoding="utf-8",
-# ) as f:
-# model_yaml = yaml.safe_load(f)
-# if "warmup" in model_yaml["train"]:
-# hpopt_arguments["min_iterations"] = ceil(
-# model_yaml["train"]["warmup"]
-# * batch_size
-# / train_dataset_size
-# )
-# elif _is_det_framework_task(task_type) or _is_seg_framework_task(
-# task_type
-# ):
 if (
 _is_cls_framework_task(task_type)
 or _is_det_framework_task(task_type)
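The commented-out block deleted in this hunk derived a minimum iteration count from the model's warm-up settings, roughly ceil(warmup * batch_size / train_dataset_size), so that early stopping could not end a trial before warm-up finished. A small illustration of that arithmetic with made-up numbers (assuming the warm-up value is given in iterations):

    from math import ceil

    # hypothetical values, for illustration only
    warmup_iters = 300          # warm-up length, assumed to be in iterations
    batch_size = 32
    train_dataset_size = 6400
    min_iterations = ceil(warmup_iters * batch_size / train_dataset_size)
    # -> 2, i.e. warm-up spans about two epochs at this batch size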
@@ -798,14 +756,6 @@ def run(self):
 for key, val in self.fixed_hp.items():
 best_config[key] = val
 
-# TODO: is it needed here?
-# # finetune stage resumes hpo trial, so warmup isn't needed
-# if task_type == TaskType.DETECTION:
-# best_config["learning_parameters.learning_rate_warmup_iters"] = 0
-# if task_type == TaskType.SEGMENTATION:
-# best_config["learning_parameters.learning_rate_fixed_iters"] = 0
-# best_config["learning_parameters.learning_rate_warmup_iters"] = 0
-
 hyper_parameters = self.environment.get_hyper_parameters()
 HpoManager.set_hyperparameter(hyper_parameters, best_config)
 
@@ -866,7 +816,8 @@ def get_num_full_iterations(environment):
 
 task_type = environment.model_template.task_type
 params = environment.get_hyper_parameters()
-if (_is_cls_framework_task(task_type)
+if (
+_is_cls_framework_task(task_type)
 or _is_det_framework_task(task_type)
 or _is_seg_framework_task(task_type)
 ):
