Skip to content

Commit

Permalink
Fix binary classification
Browse files Browse the repository at this point in the history
  • Loading branch information
harimkang committed Sep 4, 2024
1 parent 112b2b2 commit ed18cbe
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion src/otx/core/metrics/accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,8 +346,10 @@ def compute(self) -> torch.Tensor:


def _multi_class_cls_metric_callable(label_info: LabelInfo) -> MetricCollection:
num_classes = label_info.num_classes
task = "binary" if num_classes == 1 else "multiclass"
return MetricCollection(
{"accuracy": TorchmetricAcc(task="multiclass", num_classes=label_info.num_classes)},
{"accuracy": TorchmetricAcc(task=task, num_classes=num_classes)},
)


Expand Down

0 comments on commit ed18cbe

Please sign in to comment.