From ed18cbe0f152c300d543796d3b76505ae084977b Mon Sep 17 00:00:00 2001 From: "Kang, Harim" Date: Wed, 4 Sep 2024 15:59:01 +0900 Subject: [PATCH] Fix binary classification --- src/otx/core/metrics/accuracy.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/otx/core/metrics/accuracy.py b/src/otx/core/metrics/accuracy.py index 1ddd0ce4b99..2f6139ba364 100644 --- a/src/otx/core/metrics/accuracy.py +++ b/src/otx/core/metrics/accuracy.py @@ -346,8 +346,10 @@ def compute(self) -> torch.Tensor: def _multi_class_cls_metric_callable(label_info: LabelInfo) -> MetricCollection: + num_classes = label_info.num_classes + task = "binary" if num_classes == 1 else "multiclass" return MetricCollection( - {"accuracy": TorchmetricAcc(task="multiclass", num_classes=label_info.num_classes)}, + {"accuracy": TorchmetricAcc(task=task, num_classes=num_classes)}, )