@@ -50,8 +50,11 @@ def check_hpopt_available():
50
50
def run_hpo (args , environment , dataset , task_type ):
51
51
"""Update the environment with better hyper-parameters found by HPO"""
52
52
if check_hpopt_available ():
53
- if task_type not in {TaskType .CLASSIFICATION , TaskType .DETECTION ,
54
- TaskType .SEGMENTATION }:
53
+ if task_type not in {
54
+ TaskType .CLASSIFICATION ,
55
+ TaskType .DETECTION ,
56
+ TaskType .SEGMENTATION ,
57
+ }:
55
58
print (
56
59
"Currently supported task types are classification and detection."
57
60
f"{ task_type } is not supported yet."
@@ -142,18 +145,26 @@ def run_hpo_trainer(
142
145
eph_comp = [
143
146
hyper_parameters .learning_parameters .learning_rate_fixed_iters ,
144
147
hyper_parameters .learning_parameters .learning_rate_warmup_iters ,
145
- hyper_parameters .learning_parameters .num_iters
148
+ hyper_parameters .learning_parameters .num_iters ,
146
149
]
147
150
148
- total_eph = sum ( eph_comp )
149
- eph_comp = list ( map (lambda x : x * hp_config ["iterations" ] / total_eph , eph_comp ) )
150
- s_ord = sorted ([ i for i in range ( len ( eph_comp ))], key = lambda k : eph_comp [ k ], reverse = True )
151
+ eph_comp = list (
152
+ map (lambda x : x * hp_config ["iterations" ] / sum ( eph_comp ) , eph_comp )
153
+ )
151
154
152
- for i in range (hp_config ["iterations" ] - sum (map (lambda x : int (x ), eph_comp ))):
153
- eph_comp [s_ord [i ]] += 1
155
+ for val in sorted (
156
+ list (range (len (eph_comp ))),
157
+ key = lambda k : eph_comp [k ] - int (eph_comp [k ]),
158
+ reverse = True ,
159
+ )[: hp_config ["iterations" ] - sum (map (int , eph_comp ))]:
160
+ eph_comp [val ] += 1
154
161
155
- hyper_parameters .learning_parameters .learning_rate_fixed_iters = int (eph_comp [0 ])
156
- hyper_parameters .learning_parameters .learning_rate_warmup_iters = int (eph_comp [1 ])
162
+ hyper_parameters .learning_parameters .learning_rate_fixed_iters = int (
163
+ eph_comp [0 ]
164
+ )
165
+ hyper_parameters .learning_parameters .learning_rate_warmup_iters = int (
166
+ eph_comp [1 ]
167
+ )
157
168
hyper_parameters .learning_parameters .num_iters = int (eph_comp [2 ])
158
169
159
170
# set hyper-parameters and print them
@@ -648,7 +659,7 @@ def find_class(self, module_name, class_name):
648
659
def main ():
649
660
"""Run run_hpo_trainer with a pickle file"""
650
661
hp_config = None
651
- sys .path [0 ] = "" # to prevent importing nncf from this directory
662
+ sys .path [0 ] = "" # to prevent importing nncf from this directory
652
663
653
664
try :
654
665
with open (sys .argv [1 ], "rb" ) as pfile :
0 commit comments