@@ -50,7 +50,11 @@ def check_hpopt_available():
50
50
def run_hpo (args , environment , dataset , task_type ):
51
51
"""Update the environment with better hyper-parameters found by HPO"""
52
52
if check_hpopt_available ():
53
- if task_type not in {TaskType .CLASSIFICATION , TaskType .DETECTION }:
53
+ if task_type not in {
54
+ TaskType .CLASSIFICATION ,
55
+ TaskType .DETECTION ,
56
+ TaskType .SEGMENTATION ,
57
+ }:
54
58
print (
55
59
"Currently supported task types are classification and detection."
56
60
f"{ task_type } is not supported yet."
@@ -135,8 +139,33 @@ def run_hpo_trainer(
135
139
# set epoch
136
140
if task_type == TaskType .CLASSIFICATION :
137
141
(hyper_parameters .learning_parameters .max_num_epochs ) = hp_config ["iterations" ]
138
- elif task_type in ( TaskType .DETECTION , TaskType . SEGMENTATION ) :
142
+ elif task_type == TaskType .DETECTION :
139
143
hyper_parameters .learning_parameters .num_iters = hp_config ["iterations" ]
144
+ elif task_type == TaskType .SEGMENTATION :
145
+ eph_comp = [
146
+ hyper_parameters .learning_parameters .learning_rate_fixed_iters ,
147
+ hyper_parameters .learning_parameters .learning_rate_warmup_iters ,
148
+ hyper_parameters .learning_parameters .num_iters ,
149
+ ]
150
+
151
+ eph_comp = list (
152
+ map (lambda x : x * hp_config ["iterations" ] / sum (eph_comp ), eph_comp )
153
+ )
154
+
155
+ for val in sorted (
156
+ list (range (len (eph_comp ))),
157
+ key = lambda k : eph_comp [k ] - int (eph_comp [k ]),
158
+ reverse = True ,
159
+ )[: hp_config ["iterations" ] - sum (map (int , eph_comp ))]:
160
+ eph_comp [val ] += 1
161
+
162
+ hyper_parameters .learning_parameters .learning_rate_fixed_iters = int (
163
+ eph_comp [0 ]
164
+ )
165
+ hyper_parameters .learning_parameters .learning_rate_warmup_iters = int (
166
+ eph_comp [1 ]
167
+ )
168
+ hyper_parameters .learning_parameters .num_iters = int (eph_comp [2 ])
140
169
141
170
# set hyper-parameters and print them
142
171
HpoManager .set_hyperparameter (hyper_parameters , hp_config ["params" ])
@@ -630,6 +659,7 @@ def find_class(self, module_name, class_name):
630
659
def main ():
631
660
"""Run run_hpo_trainer with a pickle file"""
632
661
hp_config = None
662
+ sys .path [0 ] = "" # to prevent importing nncf from this directory
633
663
634
664
try :
635
665
with open (sys .argv [1 ], "rb" ) as pfile :
0 commit comments