From c1e8699b14e1ffb42579be7240c413471c260315 Mon Sep 17 00:00:00 2001
From: Ilya Matiach <ilmat@microsoft.com>
Date: Mon, 7 Aug 2023 16:29:50 -0400
Subject: [PATCH] add multiclass statistics to text and tabular RAI dashboards

---
 libs/core-ui/src/index.ts                     |   2 +-
 ...sUtils.ts => MulticlassStatisticsUtils.ts} |  26 ++--
 libs/core-ui/src/lib/util/StatisticsUtils.ts  |  53 ++------
 .../src/lib/util/StatisticsUtilsEnums.ts      |   8 +-
 .../datasets/MulticlassDnnModelDebugging.ts   |  10 +-
 ...tasetCohortsViewBasicElementsArePresent.ts |   6 +
 .../Controls/ModelOverview/ModelOverview.tsx  |  15 +--
 .../Controls/ModelOverview/StatsTableUtils.ts | 116 ++++++++----------
 8 files changed, 101 insertions(+), 135 deletions(-)
 rename libs/core-ui/src/lib/util/{ImageStatisticsUtils.ts => MulticlassStatisticsUtils.ts} (81%)

diff --git a/libs/core-ui/src/index.ts b/libs/core-ui/src/index.ts
index b2a0c2d5db..cf2faa2d24 100644
--- a/libs/core-ui/src/index.ts
+++ b/libs/core-ui/src/index.ts
@@ -38,7 +38,7 @@ export * from "./lib/util/getRandomId";
 export * from "./lib/util/getCohortFilterCount";
 export * from "./lib/util/getDependencyChartOptions";
 export * from "./lib/util/IGenericChartProps";
-export * from "./lib/util/ImageStatisticsUtils";
+export * from "./lib/util/MulticlassStatisticsUtils";
 export * from "./lib/util/initializeOfficeFabric";
 export * from "./lib/util/isNumber";
 export * from "./lib/util/ModelExplanationUtils";
diff --git a/libs/core-ui/src/lib/util/ImageStatisticsUtils.ts b/libs/core-ui/src/lib/util/MulticlassStatisticsUtils.ts
similarity index 81%
rename from libs/core-ui/src/lib/util/ImageStatisticsUtils.ts
rename to libs/core-ui/src/lib/util/MulticlassStatisticsUtils.ts
index f77e156cbc..4a82e6e408 100644
--- a/libs/core-ui/src/lib/util/ImageStatisticsUtils.ts
+++ b/libs/core-ui/src/lib/util/MulticlassStatisticsUtils.ts
@@ -8,15 +8,7 @@ import {
   TotalCohortSamples
 } from "../Interfaces/IStatistic";
 
-export enum ImageClassificationMetrics {
-  Accuracy = "accuracy",
-  MacroF1 = "f1",
-  MacroPrecision = "precision",
-  MacroRecall = "recall",
-  MicroF1 = "microF1",
-  MicroPrecision = "microPrecision",
-  MicroRecall = "microRecall"
-}
+import { MulticlassClassificationMetrics } from "./StatisticsUtilsEnums";
 
 interface IMicroMacroRetVal {
   macroScore: number;
@@ -64,7 +56,7 @@ export const generateMicroMacroMetrics = (
   };
 };
 
-export const generateImageStats: (
+export const generateMulticlassStats: (
   trueYs: number[],
   predYs: number[]
 ) => ILabeledStatistic[] = (
@@ -90,37 +82,37 @@ export const generateImageStats: (
       stat: predYs.length
     },
     {
-      key: ImageClassificationMetrics.Accuracy,
+      key: MulticlassClassificationMetrics.Accuracy,
       label: localization.Interpret.Statistics.accuracy,
       stat: accuracy
     },
     {
-      key: ImageClassificationMetrics.MicroPrecision,
+      key: MulticlassClassificationMetrics.MicroPrecision,
       label: localization.Interpret.Statistics.precision,
       stat: microP
     },
     {
-      key: ImageClassificationMetrics.MicroRecall,
+      key: MulticlassClassificationMetrics.MicroRecall,
       label: localization.Interpret.Statistics.recall,
       stat: microR
     },
     {
-      key: ImageClassificationMetrics.MicroF1,
+      key: MulticlassClassificationMetrics.MicroF1,
       label: localization.Interpret.Statistics.f1Score,
       stat: microF1
     },
     {
-      key: ImageClassificationMetrics.MacroPrecision,
+      key: MulticlassClassificationMetrics.MacroPrecision,
       label: localization.Interpret.Statistics.precision,
       stat: macroP
     },
     {
-      key: ImageClassificationMetrics.MacroRecall,
+      key: MulticlassClassificationMetrics.MacroRecall,
       label: localization.Interpret.Statistics.recall,
       stat: macroR
     },
     {
-      key: ImageClassificationMetrics.MacroF1,
+      key: MulticlassClassificationMetrics.MacroF1,
       label: localization.Interpret.Statistics.f1Score,
       stat: macroF1
     }
diff --git a/libs/core-ui/src/lib/util/StatisticsUtils.ts b/libs/core-ui/src/lib/util/StatisticsUtils.ts
index 27f9077557..cde9e2ca93 100644
--- a/libs/core-ui/src/lib/util/StatisticsUtils.ts
+++ b/libs/core-ui/src/lib/util/StatisticsUtils.ts
@@ -10,18 +10,14 @@ import {
 } from "../Interfaces/IStatistic";
 import { IsBinary } from "../util/ExplanationUtils";
 
-import { generateImageStats } from "./ImageStatisticsUtils";
 import { JointDataset } from "./JointDataset";
-import {
-  ClassificationEnum,
-  MulticlassClassificationEnum
-} from "./JointDatasetUtils";
+import { ClassificationEnum } from "./JointDatasetUtils";
+import { generateMulticlassStats } from "./MulticlassStatisticsUtils";
 import { generateMultilabelStats } from "./MultilabelStatisticsUtils";
 import { generateObjectDetectionStats } from "./ObjectDetectionStatisticsUtils";
 import { generateQuestionAnsweringStats } from "./QuestionAnsweringStatisticsUtils";
 import {
   BinaryClassificationMetrics,
-  MulticlassClassificationMetrics,
   RegressionMetrics
 } from "./StatisticsUtilsEnums";
 
@@ -147,27 +143,6 @@ const generateRegressionStats: (
   ];
 };
 
-const generateMulticlassStats: (outcomes: number[]) => ILabeledStatistic[] = (
-  outcomes: number[]
-): ILabeledStatistic[] => {
-  const correctCount = outcomes.filter(
-    (x) => x === MulticlassClassificationEnum.Correct
-  ).length;
-  const total = outcomes.length;
-  return [
-    {
-      key: TotalCohortSamples,
-      label: localization.Interpret.Statistics.samples,
-      stat: total
-    },
-    {
-      key: MulticlassClassificationMetrics.Accuracy,
-      label: localization.Interpret.Statistics.accuracy,
-      stat: correctCount / total
-    }
-  ];
-};
-
 export const generateMetrics: (
   jointDataset: JointDataset,
   selectionIndexes: number[][],
@@ -206,13 +181,6 @@ export const generateMetrics: (
       return generateRegressionStats(trueYSubset, predYSubset, errorsSubset);
     });
   }
-  if (modelType === ModelTypes.ImageMulticlass) {
-    return selectionIndexes.map((selectionArray) => {
-      const trueYSubset = selectionArray.map((i) => trueYs[i]);
-      const predYSubset = selectionArray.map((i) => predYs[i]);
-      return generateImageStats(trueYSubset, predYSubset);
-    });
-  }
   if (
     modelType === ModelTypes.ObjectDetection &&
     objectDetectionCache &&
@@ -225,12 +193,17 @@ export const generateMetrics: (
     );
   }
   const outcomes = jointDataset.unwrap(JointDataset.ClassificationError);
-  return selectionIndexes.map((selectionArray) => {
-    const outcomeSubset = selectionArray.map((i) => outcomes[i]);
-    if (IsBinary(modelType)) {
+  if (IsBinary(modelType)) {
+    return selectionIndexes.map((selectionArray) => {
+      const outcomeSubset = selectionArray.map((i) => outcomes[i]);
+
       return generateBinaryStats(outcomeSubset);
-    }
-    // modelType === ModelTypes.Multiclass
-    return generateMulticlassStats(outcomeSubset);
+    });
+  }
+  // remaining model types are multiclass (e.g. Multiclass, ImageMulticlass)
+  return selectionIndexes.map((selectionArray) => {
+    const trueYSubset = selectionArray.map((i) => trueYs[i]);
+    const predYSubset = selectionArray.map((i) => predYs[i]);
+    return generateMulticlassStats(trueYSubset, predYSubset);
   });
 };
diff --git a/libs/core-ui/src/lib/util/StatisticsUtilsEnums.ts b/libs/core-ui/src/lib/util/StatisticsUtilsEnums.ts
index f029438b03..6a1a51392b 100644
--- a/libs/core-ui/src/lib/util/StatisticsUtilsEnums.ts
+++ b/libs/core-ui/src/lib/util/StatisticsUtilsEnums.ts
@@ -19,5 +19,11 @@ export enum RegressionMetrics {
 }
 
 export enum MulticlassClassificationMetrics {
-  Accuracy = "accuracy"
+  Accuracy = "accuracy",
+  MacroF1 = "f1",
+  MacroPrecision = "precision",
+  MacroRecall = "recall",
+  MicroF1 = "microF1",
+  MicroPrecision = "microPrecision",
+  MicroRecall = "microRecall"
 }
diff --git a/libs/e2e/src/lib/describer/modelAssessment/datasets/MulticlassDnnModelDebugging.ts b/libs/e2e/src/lib/describer/modelAssessment/datasets/MulticlassDnnModelDebugging.ts
index 1adb1d917d..28daade702 100644
--- a/libs/e2e/src/lib/describer/modelAssessment/datasets/MulticlassDnnModelDebugging.ts
+++ b/libs/e2e/src/lib/describer/modelAssessment/datasets/MulticlassDnnModelDebugging.ts
@@ -63,7 +63,10 @@ export const MulticlassDnnModelDebugging = {
     initialCohorts: [
       {
         metrics: {
-          accuracy: "0.674"
+          accuracy: "0.674",
+          macroF1Score: "0.673",
+          macroPrecisionScore: "0.669",
+          macroRecallScore: "0.677"
         },
         name: "All data",
         sampleSize: "89"
@@ -71,7 +74,10 @@ export const MulticlassDnnModelDebugging = {
     ],
     newCohort: {
       metrics: {
-        accuracy: "0.67"
+        accuracy: "0.67",
+        macroF1Score: "0.671",
+        macroPrecisionScore: "0.666",
+        macroRecallScore: "0.675"
       },
       name: "CohortCreateE2E-multiclass-dnn",
       sampleSize: "88"
diff --git a/libs/e2e/src/lib/describer/modelAssessment/modelOverview/ensureAllModelOverviewDatasetCohortsViewBasicElementsArePresent.ts b/libs/e2e/src/lib/describer/modelAssessment/modelOverview/ensureAllModelOverviewDatasetCohortsViewBasicElementsArePresent.ts
index cc1c846203..6f00727607 100644
--- a/libs/e2e/src/lib/describer/modelAssessment/modelOverview/ensureAllModelOverviewDatasetCohortsViewBasicElementsArePresent.ts
+++ b/libs/e2e/src/lib/describer/modelAssessment/modelOverview/ensureAllModelOverviewDatasetCohortsViewBasicElementsArePresent.ts
@@ -75,6 +75,12 @@ export function ensureAllModelOverviewDatasetCohortsViewBasicElementsArePresent(
         "falseNegativeRate",
         "selectionRate"
       );
+    } else {
+      metricsOrder.push(
+        "macroF1Score",
+        "macroPrecisionScore",
+        "macroRecallScore"
+      );
     }
   }
 
diff --git a/libs/model-assessment/src/lib/ModelAssessmentDashboard/Controls/ModelOverview/ModelOverview.tsx b/libs/model-assessment/src/lib/ModelAssessmentDashboard/Controls/ModelOverview/ModelOverview.tsx
index bb524b9176..cdb8c54aa3 100644
--- a/libs/model-assessment/src/lib/ModelAssessmentDashboard/Controls/ModelOverview/ModelOverview.tsx
+++ b/libs/model-assessment/src/lib/ModelAssessmentDashboard/Controls/ModelOverview/ModelOverview.tsx
@@ -33,7 +33,6 @@ import {
   TelemetryLevels,
   TelemetryEventName,
   DatasetTaskType,
-  ImageClassificationMetrics,
   QuestionAnsweringMetrics,
   TotalCohortSamples
 } from "@responsible-ai/core-ui";
@@ -147,17 +146,13 @@ export class ModelOverview extends React.Component<
           BinaryClassificationMetrics.FalseNegativeRate,
           BinaryClassificationMetrics.SelectionRate
         ];
-      } else if (
-        this.context.dataset.task_type === DatasetTaskType.ImageClassification
-      ) {
+      } else {
         defaultSelectedMetrics = [
-          ImageClassificationMetrics.Accuracy,
-          ImageClassificationMetrics.MacroF1,
-          ImageClassificationMetrics.MacroPrecision,
-          ImageClassificationMetrics.MacroRecall
+          MulticlassClassificationMetrics.Accuracy,
+          MulticlassClassificationMetrics.MacroF1,
+          MulticlassClassificationMetrics.MacroPrecision,
+          MulticlassClassificationMetrics.MacroRecall
         ];
-      } else {
-        defaultSelectedMetrics = [MulticlassClassificationMetrics.Accuracy];
       }
     } else if (
       this.context.dataset.task_type ===
diff --git a/libs/model-assessment/src/lib/ModelAssessmentDashboard/Controls/ModelOverview/StatsTableUtils.ts b/libs/model-assessment/src/lib/ModelAssessmentDashboard/Controls/ModelOverview/StatsTableUtils.ts
index 073bc0364f..f4a4d32751 100644
--- a/libs/model-assessment/src/lib/ModelAssessmentDashboard/Controls/ModelOverview/StatsTableUtils.ts
+++ b/libs/model-assessment/src/lib/ModelAssessmentDashboard/Controls/ModelOverview/StatsTableUtils.ts
@@ -8,7 +8,6 @@ import {
   ErrorCohort,
   HighchartsNull,
   ILabeledStatistic,
-  ImageClassificationMetrics,
   MulticlassClassificationMetrics,
   MultilabelMetrics,
   ObjectDetectionMetrics,
@@ -258,74 +257,63 @@ export function getSelectableMetrics(
     taskType === DatasetTaskType.ImageClassification
   ) {
     if (isMulticlass) {
-      if (taskType === DatasetTaskType.ImageClassification) {
-        selectableMetrics.push(
-          {
-            description:
-              localization.ModelAssessment.ModelOverview.metrics.accuracy
-                .description,
-            key: ImageClassificationMetrics.Accuracy,
-            text: localization.ModelAssessment.ModelOverview.metrics.accuracy
-              .name
-          },
-          {
-            description:
-              localization.ModelAssessment.ModelOverview.metrics.precisionMacro
-                .description,
-            key: ImageClassificationMetrics.MacroPrecision,
-            text: localization.ModelAssessment.ModelOverview.metrics
-              .precisionMacro.name
-          },
-          {
-            description:
-              localization.ModelAssessment.ModelOverview.metrics.recallMacro
-                .description,
-            key: ImageClassificationMetrics.MacroRecall,
-            text: localization.ModelAssessment.ModelOverview.metrics.recallMacro
-              .name
-          },
-          {
-            description:
-              localization.ModelAssessment.ModelOverview.metrics.f1ScoreMacro
-                .description,
-            key: ImageClassificationMetrics.MacroF1,
-            text: localization.ModelAssessment.ModelOverview.metrics
-              .f1ScoreMacro.name
-          },
-          {
-            description:
-              localization.ModelAssessment.ModelOverview.metrics.precisionMicro
-                .description,
-            key: ImageClassificationMetrics.MicroPrecision,
-            text: localization.ModelAssessment.ModelOverview.metrics
-              .precisionMicro.name
-          },
-          {
-            description:
-              localization.ModelAssessment.ModelOverview.metrics.recallMicro
-                .description,
-            key: ImageClassificationMetrics.MicroRecall,
-            text: localization.ModelAssessment.ModelOverview.metrics.recallMicro
-              .name
-          },
-          {
-            description:
-              localization.ModelAssessment.ModelOverview.metrics.f1ScoreMicro
-                .description,
-            key: ImageClassificationMetrics.MicroF1,
-            text: localization.ModelAssessment.ModelOverview.metrics
-              .f1ScoreMicro.name
-          }
-        );
-      } else {
-        selectableMetrics.push({
+      selectableMetrics.push(
+        {
           description:
             localization.ModelAssessment.ModelOverview.metrics.accuracy
               .description,
           key: MulticlassClassificationMetrics.Accuracy,
           text: localization.ModelAssessment.ModelOverview.metrics.accuracy.name
-        });
-      }
+        },
+        {
+          description:
+            localization.ModelAssessment.ModelOverview.metrics.precisionMacro
+              .description,
+          key: MulticlassClassificationMetrics.MacroPrecision,
+          text: localization.ModelAssessment.ModelOverview.metrics
+            .precisionMacro.name
+        },
+        {
+          description:
+            localization.ModelAssessment.ModelOverview.metrics.recallMacro
+              .description,
+          key: MulticlassClassificationMetrics.MacroRecall,
+          text: localization.ModelAssessment.ModelOverview.metrics.recallMacro
+            .name
+        },
+        {
+          description:
+            localization.ModelAssessment.ModelOverview.metrics.f1ScoreMacro
+              .description,
+          key: MulticlassClassificationMetrics.MacroF1,
+          text: localization.ModelAssessment.ModelOverview.metrics.f1ScoreMacro
+            .name
+        },
+        {
+          description:
+            localization.ModelAssessment.ModelOverview.metrics.precisionMicro
+              .description,
+          key: MulticlassClassificationMetrics.MicroPrecision,
+          text: localization.ModelAssessment.ModelOverview.metrics
+            .precisionMicro.name
+        },
+        {
+          description:
+            localization.ModelAssessment.ModelOverview.metrics.recallMicro
+              .description,
+          key: MulticlassClassificationMetrics.MicroRecall,
+          text: localization.ModelAssessment.ModelOverview.metrics.recallMicro
+            .name
+        },
+        {
+          description:
+            localization.ModelAssessment.ModelOverview.metrics.f1ScoreMicro
+              .description,
+          key: MulticlassClassificationMetrics.MicroF1,
+          text: localization.ModelAssessment.ModelOverview.metrics.f1ScoreMicro
+            .name
+        }
+      );
     } else {
       selectableMetrics.push(
         {