From c0fd4b715bde65df5810d073703fb77123c0bff8 Mon Sep 17 00:00:00 2001
From: manivoxel51
Date: Tue, 14 Jan 2025 15:13:31 -0800
Subject: [PATCH] Add correct/incorrect to summary card for
 classification/non-binary evaluations

---
 .../NativeModelEvaluationView/Evaluation.tsx  | 20 +++++++++++++++++++
 plugins/panels/model_evaluation/__init__.py   | 15 ++++++++++++++
 2 files changed, 35 insertions(+)

diff --git a/app/packages/core/src/plugins/SchemaIO/components/NativeModelEvaluationView/Evaluation.tsx b/app/packages/core/src/plugins/SchemaIO/components/NativeModelEvaluationView/Evaluation.tsx
index 5f10383da7..9a4eb9d1c4 100644
--- a/app/packages/core/src/plugins/SchemaIO/components/NativeModelEvaluationView/Evaluation.tsx
+++ b/app/packages/core/src/plugins/SchemaIO/components/NativeModelEvaluationView/Evaluation.tsx
@@ -210,6 +210,8 @@ export default function Evaluation(props: EvaluationProps) {
   const isBinaryClassification =
     evaluationType === "classification" && evaluationMethod === "binary";
   const showTpFpFn = isObjectDetection || isBinaryClassification;
+  const isNoneBinaryClassification =
+    isClassification && evaluationMethod !== "binary";
   const infoRows = [
     {
       id: "evaluation_key",
@@ -465,6 +467,24 @@ export default function Evaluation(props: EvaluationProps) {
         : false,
       hide: !showTpFpFn,
     },
+    {
+      id: true,
+      property: "Correct",
+      value: evaluationMetrics.num_correct,
+      compareValue: compareEvaluationMetrics.num_correct,
+      lesserIsBetter: false,
+      filterable: true,
+      hide: !isNoneBinaryClassification,
+    },
+    {
+      id: false,
+      property: "Incorrect",
+      value: evaluationMetrics.num_incorrect,
+      compareValue: compareEvaluationMetrics.num_incorrect,
+      lesserIsBetter: false,
+      filterable: true,
+      hide: !isNoneBinaryClassification,
+    },
   ];
 
   const perClassPerformance = {};
diff --git a/plugins/panels/model_evaluation/__init__.py b/plugins/panels/model_evaluation/__init__.py
index ed71f85ec8..8677cc6dfc 100644
--- a/plugins/panels/model_evaluation/__init__.py
+++ b/plugins/panels/model_evaluation/__init__.py
@@ -325,6 +325,11 @@ def get_mask_targets(self, dataset, gt_field):
 
         return None
 
+    def get_correct_incorrect(self, results):
+        correct = np.count_nonzero(results.ypred == results.ytrue)
+        incorrect = np.count_nonzero(results.ypred != results.ytrue)
+        return correct, incorrect
+
     def load_evaluation(self, ctx):
         view_state = ctx.panel.get_state("view") or {}
         eval_key = view_state.get("key")
@@ -362,6 +367,16 @@
             )
             metrics["mAP"] = self.get_map(results)
             metrics["mAR"] = self.get_mar(results)
+
+            if (
+                info.config.type == "classification"
+                and info.config.method != "binary"
+            ):
+                (
+                    metrics["num_correct"],
+                    metrics["num_incorrect"],
+                ) = self.get_correct_incorrect(results)
+
             evaluation_data = {
                 "metrics": metrics,
                 "info": serialized_info,
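
Reviewer note, not part of the patch: a minimal, standalone sketch of the
counting logic that the new get_correct_incorrect() helper applies. It
assumes, as the patch does, that the classification results object exposes
ground-truth and predicted labels as the array attributes results.ytrue and
results.ypred; the label values below are made up for illustration:

    import numpy as np

    # Hypothetical stand-ins for results.ytrue / results.ypred
    ytrue = np.array(["cat", "dog", "cat", "bird"])
    ypred = np.array(["cat", "dog", "bird", "bird"])

    # Same counting as get_correct_incorrect(): element-wise comparison,
    # then count the matching / non-matching positions
    num_correct = np.count_nonzero(ypred == ytrue)    # 3
    num_incorrect = np.count_nonzero(ypred != ytrue)  # 1

These counts are what the panel stores as metrics["num_correct"] and
metrics["num_incorrect"], and what the summary card renders as the "Correct"
and "Incorrect" rows for non-binary classification evaluations.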