Skip to content

Commit 14c1d94

Browse files
committed
make code fit to format
1 parent 6113bcd commit 14c1d94

File tree

1 file changed

+22
-11
lines changed

1 file changed

+22
-11
lines changed

ote_cli/ote_cli/utils/hpo.py

+22-11
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,11 @@ def check_hpopt_available():
5050
def run_hpo(args, environment, dataset, task_type):
5151
"""Update the environment with better hyper-parameters found by HPO"""
5252
if check_hpopt_available():
53-
if task_type not in {TaskType.CLASSIFICATION, TaskType.DETECTION,
54-
TaskType.SEGMENTATION}:
53+
if task_type not in {
54+
TaskType.CLASSIFICATION,
55+
TaskType.DETECTION,
56+
TaskType.SEGMENTATION,
57+
}:
5558
print(
5659
"Currently supported task types are classification and detection."
5760
f"{task_type} is not supported yet."
@@ -142,18 +145,26 @@ def run_hpo_trainer(
142145
eph_comp = [
143146
hyper_parameters.learning_parameters.learning_rate_fixed_iters,
144147
hyper_parameters.learning_parameters.learning_rate_warmup_iters,
145-
hyper_parameters.learning_parameters.num_iters
148+
hyper_parameters.learning_parameters.num_iters,
146149
]
147150

148-
total_eph = sum(eph_comp)
149-
eph_comp = list(map(lambda x: x * hp_config["iterations"] / total_eph, eph_comp))
150-
s_ord = sorted([i for i in range(len(eph_comp))], key=lambda k: eph_comp[k], reverse=True)
151+
eph_comp = list(
152+
map(lambda x: x * hp_config["iterations"] / sum(eph_comp), eph_comp)
153+
)
151154

152-
for i in range(hp_config["iterations"] - sum(map(lambda x: int(x), eph_comp))):
153-
eph_comp[s_ord[i]] += 1
155+
for val in sorted(
156+
list(range(len(eph_comp))),
157+
key=lambda k: eph_comp[k] - int(eph_comp[k]),
158+
reverse=True,
159+
)[: hp_config["iterations"] - sum(map(int, eph_comp))]:
160+
eph_comp[val] += 1
154161

155-
hyper_parameters.learning_parameters.learning_rate_fixed_iters = int(eph_comp[0])
156-
hyper_parameters.learning_parameters.learning_rate_warmup_iters = int(eph_comp[1])
162+
hyper_parameters.learning_parameters.learning_rate_fixed_iters = int(
163+
eph_comp[0]
164+
)
165+
hyper_parameters.learning_parameters.learning_rate_warmup_iters = int(
166+
eph_comp[1]
167+
)
157168
hyper_parameters.learning_parameters.num_iters = int(eph_comp[2])
158169

159170
# set hyper-parameters and print them
@@ -648,7 +659,7 @@ def find_class(self, module_name, class_name):
648659
def main():
649660
"""Run run_hpo_trainer with a pickle file"""
650661
hp_config = None
651-
sys.path[0] = "" # to prevent importing nncf from this directory
662+
sys.path[0] = "" # to prevent importing nncf from this directory
652663

653664
try:
654665
with open(sys.argv[1], "rb") as pfile:

0 commit comments

Comments
 (0)