From 2aab8df2a1dd049fc62f5329e9719708b1365f7a Mon Sep 17 00:00:00 2001 From: Dick Ameln Date: Tue, 24 May 2022 17:56:52 +0200 Subject: [PATCH 1/2] use is_anomalous attribute instead of string matching --- .../prediction_to_annotation_converter.py | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/ote_sdk/ote_sdk/usecases/exportable_code/prediction_to_annotation_converter.py b/ote_sdk/ote_sdk/usecases/exportable_code/prediction_to_annotation_converter.py index 853537085b5..4d764905e47 100644 --- a/ote_sdk/ote_sdk/usecases/exportable_code/prediction_to_annotation_converter.py +++ b/ote_sdk/ote_sdk/usecases/exportable_code/prediction_to_annotation_converter.py @@ -259,10 +259,8 @@ class AnomalyClassificationToAnnotationConverter(IPredictionToAnnotationConverte def __init__(self, label_schema: LabelSchemaEntity): labels = label_schema.get_labels(include_empty=False) - self.normal_label = [label for label in labels if label.name == "Normal"][0] - self.anomalous_label = [label for label in labels if label.name == "Anomalous"][ - 0 - ] + self.normal_label = [label for label in labels if not label.is_anomalous][0] + self.anomalous_label = [label for label in labels if label.is_anomalous][0] def convert_to_annotation( self, predictions: np.ndarray, metadata: Dict[str, Any] @@ -290,10 +288,8 @@ class AnomalySegmentationToAnnotationConverter(IPredictionToAnnotationConverter) def __init__(self, label_schema: LabelSchemaEntity): labels = label_schema.get_labels(include_empty=False) - self.normal_label = [label for label in labels if label.name == "Normal"][0] - self.anomalous_label = [label for label in labels if label.name == "Anomalous"][ - 0 - ] + self.normal_label = [label for label in labels if not label.is_anomalous][0] + self.anomalous_label = [label for label in labels if label.is_anomalous][0] self.label_map = {0: self.normal_label, 1: self.anomalous_label} def convert_to_annotation( @@ -327,10 +323,8 @@ def __init__(self, label_schema: LabelSchemaEntity): :param label_schema: Label Schema containing the label info of the task """ labels = label_schema.get_labels(include_empty=False) - self.normal_label = [label for label in labels if label.name == "Normal"][0] - self.anomalous_label = [label for label in labels if label.name == "Anomalous"][ - 0 - ] + self.normal_label = [label for label in labels if not label.is_anomalous][0] + self.anomalous_label = [label for label in labels if label.is_anomalous][0] self.label_map = {0: self.normal_label, 1: self.anomalous_label} def convert_to_annotation( From 16f2138b8af77b65f4f4b809328749c043bad5e4 Mon Sep 17 00:00:00 2001 From: Dick Ameln Date: Mon, 30 May 2022 11:35:18 +0200 Subject: [PATCH 2/2] add is_anomalous attribute to labels in test cases --- ...test_prediction_to_annotation_converter.py | 38 ++++++++++++++++--- 1 file changed, 32 insertions(+), 6 deletions(-) diff --git a/ote_sdk/ote_sdk/tests/usecases/exportable_code/test_prediction_to_annotation_converter.py b/ote_sdk/ote_sdk/tests/usecases/exportable_code/test_prediction_to_annotation_converter.py index 0788c0d0ec2..2d3a176c869 100644 --- a/ote_sdk/ote_sdk/tests/usecases/exportable_code/test_prediction_to_annotation_converter.py +++ b/ote_sdk/ote_sdk/tests/usecases/exportable_code/test_prediction_to_annotation_converter.py @@ -293,7 +293,10 @@ def test_create_converter(self): name="Normal", domain=Domain.ANOMALY_CLASSIFICATION, id=ID("1") ), LabelEntity( - name="Anomalous", domain=Domain.ANOMALY_CLASSIFICATION, id=ID("2") + name="Anomalous", + domain=Domain.ANOMALY_CLASSIFICATION, + id=ID("2"), + is_anomalous=True, ), ] label_group = LabelGroup( @@ -310,7 +313,12 @@ def test_create_converter(self): # "ANOMALY_DETECTION" is specified as "converter_type" labels = [ LabelEntity(name="Normal", domain=Domain.ANOMALY_DETECTION, id=ID("1")), - LabelEntity(name="Anomalous", domain=Domain.ANOMALY_DETECTION, id=ID("2")), + LabelEntity( + name="Anomalous", + domain=Domain.ANOMALY_DETECTION, + id=ID("2"), + is_anomalous=True, + ), ] label_group = LabelGroup(name="Anomaly detection labels group", labels=labels) label_schema = LabelSchemaEntity(label_groups=[label_group]) @@ -325,7 +333,10 @@ def test_create_converter(self): labels = [ LabelEntity(name="Normal", domain=Domain.ANOMALY_SEGMENTATION, id=ID("1")), LabelEntity( - name="Anomalous", domain=Domain.ANOMALY_SEGMENTATION, id=ID("2") + name="Anomalous", + domain=Domain.ANOMALY_SEGMENTATION, + id=ID("2"), + is_anomalous=True, ), ] label_group = LabelGroup(name="Anomaly detection labels group", labels=labels) @@ -947,8 +958,18 @@ def test_anomaly_classification_to_annotation_init( non_empty_labels = [ LabelEntity(name="Normal", domain=Domain.CLASSIFICATION, id=ID("1")), LabelEntity(name="Normal", domain=Domain.CLASSIFICATION, id=ID("2")), - LabelEntity(name="Anomalous", domain=Domain.CLASSIFICATION, id=ID("1")), - LabelEntity(name="Anomalous", domain=Domain.CLASSIFICATION, id=ID("2")), + LabelEntity( + name="Anomalous", + domain=Domain.CLASSIFICATION, + id=ID("1"), + is_anomalous=True, + ), + LabelEntity( + name="Anomalous", + domain=Domain.CLASSIFICATION, + id=ID("2"), + is_anomalous=True, + ), ] label_group = LabelGroup( name="Classification labels group", labels=non_empty_labels @@ -1030,7 +1051,12 @@ def check_annotation(actual_annotation: Annotation, expected_labels: list): non_empty_labels = [ LabelEntity(name="Normal", domain=Domain.CLASSIFICATION, id=ID("1")), - LabelEntity(name="Anomalous", domain=Domain.CLASSIFICATION, id=ID("2")), + LabelEntity( + name="Anomalous", + domain=Domain.CLASSIFICATION, + id=ID("2"), + is_anomalous=True, + ), ] label_group = LabelGroup( name="Anomaly classification labels group", labels=non_empty_labels