From f1e395adfefd2347bea6b2c8713181c05bac867b Mon Sep 17 00:00:00 2001 From: "Kang, Harim" Date: Wed, 4 Sep 2024 16:23:57 +0900 Subject: [PATCH] Add unit-tests --- tests/unit/core/metrics/test_accuracy.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/tests/unit/core/metrics/test_accuracy.py b/tests/unit/core/metrics/test_accuracy.py index d04d253575f..8370fee09f6 100644 --- a/tests/unit/core/metrics/test_accuracy.py +++ b/tests/unit/core/metrics/test_accuracy.py @@ -9,9 +9,11 @@ HlabelAccuracy, MixedHLabelAccuracy, MulticlassAccuracywithLabelGroup, + MultiClassClsMetricCallable, MultilabelAccuracywithLabelGroup, ) from otx.core.types.label import HLabelInfo, LabelInfo +from torchmetrics.classification.accuracy import BinaryAccuracy, MulticlassAccuracy class TestAccuracy: @@ -45,6 +47,16 @@ def test_multiclass_accuracy(self, fxt_multiclass_labelinfo: LabelInfo) -> None: acc = result["accuracy"] assert round(acc.item(), 3) == 0.792 + def test_default_multi_class_cls_metric_callable(self, fxt_multiclass_labelinfo: LabelInfo) -> None: + assert fxt_multiclass_labelinfo.num_classes > 1 + metric = MultiClassClsMetricCallable(fxt_multiclass_labelinfo) + assert isinstance(metric.accuracy, MulticlassAccuracy) + + one_class_label_info = LabelInfo(label_names=["class1"], label_groups=[["class1"]]) + assert one_class_label_info.num_classes == 1 + binary_metric = MultiClassClsMetricCallable(one_class_label_info) + assert isinstance(binary_metric.accuracy, BinaryAccuracy) + def test_multilabel_accuracy(self, fxt_multilabel_labelinfo: LabelInfo) -> None: """Check whether accuracy is same with OTX1.x version.""" preds = [