Skip to content

Commit abafaa3

Browse files
Merge pull request #980 from openvinotoolkit/leo/modeltemplate-isglobal
[OTE_SDK] expand ModelTemplate.is_global
2 parents 4fc20a2 + 5060c88 commit abafaa3

File tree

2 files changed

+25
-13
lines changed

2 files changed

+25
-13
lines changed

ote_sdk/ote_sdk/entities/model_template.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -424,7 +424,10 @@ def is_task_global(self) -> bool:
424424
"""
425425
Returns ``True`` if the task is global task i.e. if task produces global labels
426426
"""
427-
return self.task_type in [TaskType.CLASSIFICATION]
427+
return self.task_type in (
428+
TaskType.CLASSIFICATION,
429+
TaskType.ANOMALY_CLASSIFICATION,
430+
)
428431

429432
def supports_auto_hpo(self) -> bool:
430433
"""

ote_sdk/ote_sdk/tests/entities/test_model_template.py

+21-12
Original file line numberDiff line numberDiff line change
@@ -965,27 +965,36 @@ def test_model_template_is_task_global(self):
965965
Test passes if is_task_global method of ModelTemplate object returns expected bool values related to
966966
task_type attribute
967967
<b>Steps</b>
968-
1. Check is_task_global method returns True if task_type equal to CLASSIFICATION
969-
2. Check is_task_global method returns False if task_type not equal to CLASSIFICATION
968+
1. Check is_task_global method returns True if task_type equal to CLASSIFICATION or ANOMALY_CLASSIFICATION
969+
2. Check is_task_global method returns False if task_type not equal to CLASSIFICATION or ANOMALY_CLASSIFICATION
970970
"""
971-
# Check is_task_global method returns True
972-
default_parameters = self.default_model_parameters()
973-
task_global_parameters = dict(default_parameters)
974-
task_global_parameters["task_type"] = TaskType.CLASSIFICATION
975-
task_global_model_template = ModelTemplate(**task_global_parameters)
976-
assert task_global_model_template.is_task_global()
977-
# Check is_task_global method returns False
971+
# Check is_task_global method returns True for CLASSIFICATION and ANOMALY_CLASSIFICATION
972+
for global_task_type in (
973+
TaskType.CLASSIFICATION,
974+
TaskType.ANOMALY_CLASSIFICATION,
975+
):
976+
default_parameters = self.default_model_parameters()
977+
task_global_parameters = dict(default_parameters)
978+
task_global_parameters["task_type"] = global_task_type
979+
task_global_model_template = ModelTemplate(**task_global_parameters)
980+
assert (
981+
task_global_model_template.is_task_global()
982+
), f"Expected True value returned by is_task_global for {global_task_type}"
983+
# Check is_task_global method returns False for the other tasks
978984
non_global_task_parameters = dict(default_parameters)
979985
non_global_tasks_list = []
980986
for task_type in TaskType:
981-
if task_type != TaskType.CLASSIFICATION:
987+
if task_type not in (
988+
TaskType.CLASSIFICATION,
989+
TaskType.ANOMALY_CLASSIFICATION,
990+
):
982991
non_global_tasks_list.append(task_type)
983992
for non_global_task in non_global_tasks_list:
984993
non_global_task_parameters["task_type"] = non_global_task
985994
non_global_task_template = ModelTemplate(**non_global_task_parameters)
986995
assert not non_global_task_template.is_task_global(), (
987-
f"Expected False value returned by is_task_global method for {non_global_task}, only CLASSIFICATION "
988-
f"task type is global"
996+
f"Expected False value returned by is_task_global method for {non_global_task}, "
997+
f"only CLASSIFICATION and ANOMALY_CLASSIFICATION task types are global"
989998
)
990999

9911000

0 commit comments

Comments
 (0)