Skip to content

Commit 2aab8df

Browse files
committed
use is_anomalous attribute instead of string matching
1 parent 38a4d88 commit 2aab8df

File tree

1 file changed

+6
-12
lines changed

1 file changed

+6
-12
lines changed

ote_sdk/ote_sdk/usecases/exportable_code/prediction_to_annotation_converter.py

+6-12
Original file line numberDiff line numberDiff line change
@@ -259,10 +259,8 @@ class AnomalyClassificationToAnnotationConverter(IPredictionToAnnotationConverte
259259

260260
def __init__(self, label_schema: LabelSchemaEntity):
261261
labels = label_schema.get_labels(include_empty=False)
262-
self.normal_label = [label for label in labels if label.name == "Normal"][0]
263-
self.anomalous_label = [label for label in labels if label.name == "Anomalous"][
264-
0
265-
]
262+
self.normal_label = [label for label in labels if not label.is_anomalous][0]
263+
self.anomalous_label = [label for label in labels if label.is_anomalous][0]
266264

267265
def convert_to_annotation(
268266
self, predictions: np.ndarray, metadata: Dict[str, Any]
@@ -290,10 +288,8 @@ class AnomalySegmentationToAnnotationConverter(IPredictionToAnnotationConverter)
290288

291289
def __init__(self, label_schema: LabelSchemaEntity):
292290
labels = label_schema.get_labels(include_empty=False)
293-
self.normal_label = [label for label in labels if label.name == "Normal"][0]
294-
self.anomalous_label = [label for label in labels if label.name == "Anomalous"][
295-
0
296-
]
291+
self.normal_label = [label for label in labels if not label.is_anomalous][0]
292+
self.anomalous_label = [label for label in labels if label.is_anomalous][0]
297293
self.label_map = {0: self.normal_label, 1: self.anomalous_label}
298294

299295
def convert_to_annotation(
@@ -327,10 +323,8 @@ def __init__(self, label_schema: LabelSchemaEntity):
327323
:param label_schema: Label Schema containing the label info of the task
328324
"""
329325
labels = label_schema.get_labels(include_empty=False)
330-
self.normal_label = [label for label in labels if label.name == "Normal"][0]
331-
self.anomalous_label = [label for label in labels if label.name == "Anomalous"][
332-
0
333-
]
326+
self.normal_label = [label for label in labels if not label.is_anomalous][0]
327+
self.anomalous_label = [label for label in labels if label.is_anomalous][0]
334328
self.label_map = {0: self.normal_label, 1: self.anomalous_label}
335329

336330
def convert_to_annotation(

0 commit comments

Comments
 (0)