Skip to content

Commit c1c3753

Browse files
authored
Merge pull request #919 from openvinotoolkit/es/hpo_segmentation
[HPO] enable HPO with segmentation
2 parents e8d24d0 + 14c1d94 commit c1c3753

File tree

2 files changed

+38
-2
lines changed

2 files changed

+38
-2
lines changed

ote_cli/ote_cli/utils/hpo.py

+32-2
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,11 @@ def check_hpopt_available():
5050
def run_hpo(args, environment, dataset, task_type):
5151
"""Update the environment with better hyper-parameters found by HPO"""
5252
if check_hpopt_available():
53-
if task_type not in {TaskType.CLASSIFICATION, TaskType.DETECTION}:
53+
if task_type not in {
54+
TaskType.CLASSIFICATION,
55+
TaskType.DETECTION,
56+
TaskType.SEGMENTATION,
57+
}:
5458
print(
5559
"Currently supported task types are classification and detection."
5660
f"{task_type} is not supported yet."
@@ -135,8 +139,33 @@ def run_hpo_trainer(
135139
# set epoch
136140
if task_type == TaskType.CLASSIFICATION:
137141
(hyper_parameters.learning_parameters.max_num_epochs) = hp_config["iterations"]
138-
elif task_type in (TaskType.DETECTION, TaskType.SEGMENTATION):
142+
elif task_type == TaskType.DETECTION:
139143
hyper_parameters.learning_parameters.num_iters = hp_config["iterations"]
144+
elif task_type == TaskType.SEGMENTATION:
145+
eph_comp = [
146+
hyper_parameters.learning_parameters.learning_rate_fixed_iters,
147+
hyper_parameters.learning_parameters.learning_rate_warmup_iters,
148+
hyper_parameters.learning_parameters.num_iters,
149+
]
150+
151+
eph_comp = list(
152+
map(lambda x: x * hp_config["iterations"] / sum(eph_comp), eph_comp)
153+
)
154+
155+
for val in sorted(
156+
list(range(len(eph_comp))),
157+
key=lambda k: eph_comp[k] - int(eph_comp[k]),
158+
reverse=True,
159+
)[: hp_config["iterations"] - sum(map(int, eph_comp))]:
160+
eph_comp[val] += 1
161+
162+
hyper_parameters.learning_parameters.learning_rate_fixed_iters = int(
163+
eph_comp[0]
164+
)
165+
hyper_parameters.learning_parameters.learning_rate_warmup_iters = int(
166+
eph_comp[1]
167+
)
168+
hyper_parameters.learning_parameters.num_iters = int(eph_comp[2])
140169

141170
# set hyper-parameters and print them
142171
HpoManager.set_hyperparameter(hyper_parameters, hp_config["params"])
@@ -630,6 +659,7 @@ def find_class(self, module_name, class_name):
630659
def main():
631660
"""Run run_hpo_trainer with a pickle file"""
632661
hp_config = None
662+
sys.path[0] = "" # to prevent importing nncf from this directory
633663

634664
try:
635665
with open(sys.argv[1], "rb") as pfile:

tests/ote_cli/test_ote_cli_tools_segmentation.py

+6
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
ote_eval_deployment_testing,
3333
ote_eval_openvino_testing,
3434
ote_eval_testing,
35+
ote_hpo_testing,
3536
ote_train_testing,
3637
ote_export_testing,
3738
pot_optimize_testing,
@@ -118,6 +119,11 @@ def test_ote_eval_deployment(self, template):
118119
def test_ote_demo_deployment(self, template):
119120
ote_demo_deployment_testing(template, root, ote_dir, args)
120121

122+
@e2e_pytest_component
123+
@pytest.mark.parametrize("template", templates, ids=templates_ids)
124+
def test_ote_hpo(self, template):
125+
ote_hpo_testing(template, root, ote_dir, args)
126+
121127
@e2e_pytest_component
122128
@pytest.mark.parametrize("template", templates, ids=templates_ids)
123129
def test_nncf_optimize(self, template):

0 commit comments

Comments
 (0)