24
24
from math import ceil
25
25
from os import path as osp
26
26
from pathlib import Path
27
- from typing import Callable , Optional
27
+ from typing import Optional
28
28
29
29
import torch
30
30
import yaml
31
+ from mmcv .utils import ConfigDict # pylint: disable=import-error
31
32
from ote_sdk .configuration .helper import create
32
33
from ote_sdk .entities .model import ModelEntity
33
34
from ote_sdk .entities .model_template import TaskType
34
35
from ote_sdk .entities .subset import Subset
35
36
from ote_sdk .entities .task_environment import TaskEnvironment
36
37
from ote_sdk .entities .train_parameters import TrainParameters , UpdateProgressCallback
37
- from mmcv .utils import ConfigDict
38
38
39
39
from ote_cli .datasets import get_dataset_class
40
40
from ote_cli .utils .importing import get_impl_class
@@ -280,7 +280,7 @@ def run_hpo_trainer(
280
280
281
281
# make callback to report score to hpopt every epoch
282
282
train_param = TrainParameters (
283
- False , HpoCallback (hp_config , hp_config ["metric" ], task ), ModelSavedCallback ()
283
+ False , HpoCallback (hp_config , hp_config ["metric" ], task ), None
284
284
)
285
285
286
286
task .train (dataset = dataset , output_model = output_model , train_parameters = train_param )
@@ -349,19 +349,6 @@ def set_resume_path_to_config(self, resume_path):
349
349
350
350
def prepare_hpo (self , hp_config ):
351
351
"""update config of the each task framework for the HPO"""
352
- # if _is_cls_framework_task(self._task_type):
353
- # # pylint: disable=attribute-defined-outside-init
354
- # self._scratch_space = _get_hpo_trial_workdir(hp_config)
355
- # self._cfg.data.save_dir = self._scratch_space
356
- # self._cfg.model.save_all_chkpts = True
357
- # elif _is_det_framework_task(self._task_type) or _is_seg_framework_task(
358
- # self._task_type
359
- # ):
360
- # self._config.work_dir = _get_hpo_trial_workdir(hp_config)
361
- # self._config.checkpoint_config["max_keep_ckpts"] = (
362
- # hp_config["iterations"] + 10
363
- # )
364
- # self._config.checkpoint_config["interval"] = 1
365
352
if (
366
353
_is_cls_framework_task (self ._task_type )
367
354
or _is_det_framework_task (self ._task_type )
@@ -373,7 +360,9 @@ def prepare_hpo(self, hp_config):
373
360
)
374
361
)
375
362
self .set_override_configurations (cfg )
376
- self ._output_path = _get_hpo_trial_workdir (hp_config )
363
+ self ._output_path = _get_hpo_trial_workdir ( # pylint: disable=attribute-defined-outside-init
364
+ hp_config
365
+ )
377
366
378
367
def prepare_saving_initial_weight (self , save_path ):
379
368
"""add a hook which saves initial model weight before training"""
@@ -382,11 +371,10 @@ def prepare_saving_initial_weight(self, save_path):
382
371
or _is_det_framework_task (task_type )
383
372
or _is_seg_framework_task (task_type )
384
373
):
385
- cfg = { "custom_hooks" : [
386
- ConfigDict (
387
- type = "SaveInitialWeightHook" ,
388
- save_path = save_path
389
- )]
374
+ cfg = {
375
+ "custom_hooks" : [
376
+ ConfigDict (type = "SaveInitialWeightHook" , save_path = save_path )
377
+ ]
390
378
}
391
379
self .set_override_configurations (cfg )
392
380
else :
@@ -395,7 +383,6 @@ def prepare_saving_initial_weight(self, save_path):
395
383
"initial weight should be saved before HPO."
396
384
)
397
385
398
-
399
386
return HpoTrainTask
400
387
401
388
@@ -446,7 +433,8 @@ def _load_hpopt_config(file_path):
446
433
def _get_best_model_weight_path (hpo_dir : str , trial_num : str , task_type : TaskType ):
447
434
"""Return best model weight from HPO trial directory"""
448
435
best_weight_path = None
449
- if (_is_cls_framework_task (task_type )
436
+ if (
437
+ _is_cls_framework_task (task_type )
450
438
or _is_det_framework_task (task_type )
451
439
or _is_seg_framework_task (task_type )
452
440
):
@@ -457,7 +445,7 @@ def _get_best_model_weight_path(hpo_dir: str, trial_num: str, task_type: TaskTyp
457
445
break
458
446
elif _is_anomaly_framework_task (task_type ):
459
447
# TODO need to implement later
460
- best_weight_path = ""
448
+ pass
461
449
462
450
return best_weight_path
463
451
@@ -500,14 +488,6 @@ def __call__(self, progress: float, score: Optional[float] = None):
500
488
self .hpo_task .cancel_training ()
501
489
502
490
503
- class ModelSavedCallback (Callable ):
504
- def __call__ (self , path : str , event : str ):
505
- if event == "initialized" :
506
- print (
507
- f"********** called model saved callback ({ path } , { event } ) **************"
508
- )
509
-
510
-
511
491
class HpoManager :
512
492
"""Manage overall HPO process"""
513
493
@@ -555,7 +535,6 @@ def __init__(
555
535
train_dataset_size = len (dataset .get_subset (Subset .TRAINING ))
556
536
val_dataset_size = len (dataset .get_subset (Subset .VALIDATION ))
557
537
558
-
559
538
# make batch size range lower than train set size
560
539
env_hp = self .environment .get_hyper_parameters ()
561
540
if (
@@ -627,27 +606,6 @@ def __init__(
627
606
# Prevent each trials from being stopped during warmup stage
628
607
batch_size = default_hyper_parameters .get (batch_size_name )
629
608
if "min_iterations" not in hpopt_cfg and batch_size is not None :
630
- # if _is_cls_framework_task(task_type):
631
- # with open(
632
- # osp.join(
633
- # osp.dirname(
634
- # self.environment.model_template.model_template_path
635
- # ),
636
- # "model.yaml",
637
- # ),
638
- # "r",
639
- # encoding="utf-8",
640
- # ) as f:
641
- # model_yaml = yaml.safe_load(f)
642
- # if "warmup" in model_yaml["train"]:
643
- # hpopt_arguments["min_iterations"] = ceil(
644
- # model_yaml["train"]["warmup"]
645
- # * batch_size
646
- # / train_dataset_size
647
- # )
648
- # elif _is_det_framework_task(task_type) or _is_seg_framework_task(
649
- # task_type
650
- # ):
651
609
if (
652
610
_is_cls_framework_task (task_type )
653
611
or _is_det_framework_task (task_type )
@@ -798,14 +756,6 @@ def run(self):
798
756
for key , val in self .fixed_hp .items ():
799
757
best_config [key ] = val
800
758
801
- # TODO: is it needed here?
802
- # # finetune stage resumes hpo trial, so warmup isn't needed
803
- # if task_type == TaskType.DETECTION:
804
- # best_config["learning_parameters.learning_rate_warmup_iters"] = 0
805
- # if task_type == TaskType.SEGMENTATION:
806
- # best_config["learning_parameters.learning_rate_fixed_iters"] = 0
807
- # best_config["learning_parameters.learning_rate_warmup_iters"] = 0
808
-
809
759
hyper_parameters = self .environment .get_hyper_parameters ()
810
760
HpoManager .set_hyperparameter (hyper_parameters , best_config )
811
761
@@ -866,7 +816,8 @@ def get_num_full_iterations(environment):
866
816
867
817
task_type = environment .model_template .task_type
868
818
params = environment .get_hyper_parameters ()
869
- if (_is_cls_framework_task (task_type )
819
+ if (
820
+ _is_cls_framework_task (task_type )
870
821
or _is_det_framework_task (task_type )
871
822
or _is_seg_framework_task (task_type )
872
823
):
0 commit comments