diff --git a/apps/dashboard/src/app/textApplications.ts b/apps/dashboard/src/app/textApplications.ts index cb7414ddf0..2545f5b3b4 100644 --- a/apps/dashboard/src/app/textApplications.ts +++ b/apps/dashboard/src/app/textApplications.ts @@ -13,6 +13,7 @@ import { emotion, emotionModelExplanationData } from "../model-assessment-text/__mock_data__/emotion"; +import { squad } from "../model-assessment-text/__mock_data__/squad"; import { IDataSet, @@ -56,6 +57,10 @@ export const textApplications: ITextApplications = { classDimension: 3, dataset: emotion, modelExplanationData: [emotionModelExplanationData] + } as IModelAssessmentDataSet, + squad: { + classDimension: 3, + dataset: squad } as IModelAssessmentDataSet }, versions: { "1": 1, "2:Static-View": 2 } diff --git a/apps/dashboard/src/model-assessment-text/__mock_data__/squad.ts b/apps/dashboard/src/model-assessment-text/__mock_data__/squad.ts new file mode 100644 index 0000000000..e773e0b509 --- /dev/null +++ b/apps/dashboard/src/model-assessment-text/__mock_data__/squad.ts @@ -0,0 +1,46 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. 
+
+import { DatasetTaskType, IDataset } from "@responsible-ai/core-ui";
+
+export const squad: IDataset = {
+  categorical_features: [],
+  class_names: undefined,
+  feature_names: [
+    "context_positive_words",
+    "context_negative_words",
+    "context_negation_words",
+    "context_negated_entities",
+    "context_named_persons",
+    "context_sentence_length",
+    "question_positive_words",
+    "question_negative_words",
+    "question_negation_words",
+    "question_negated_entities",
+    "question_named_persons",
+    "question_sentence_length"
+  ],
+  features: [
+    [42, 0, 0, 0, 2, 695, 3, 0, 0, 0, 0, 71],
+    [42, 0, 0, 0, 2, 695, 3, 0, 0, 0, 0, 49],
+    [42, 0, 0, 0, 2, 695, 5, 0, 0, 0, 1, 76],
+    [42, 0, 0, 0, 2, 695, 1, 0, 0, 0, 1, 33],
+    [42, 0, 0, 0, 2, 695, 4, 0, 0, 0, 0, 52]
+  ],
+  predicted_y: [
+    "Saint Bernadette Soubirous",
+    "a copper statue of Christ",
+    "the Main Building",
+    "a Marian place of prayer and reflection",
+    "a golden statue of the Virgin Mary"
+  ],
+  target_column: "answers",
+  task_type: DatasetTaskType.QuestionAnswering,
+  true_y: [
+    "Saint Bernadette Soubirous",
+    "a copper statue of Christ",
+    "the Main Building",
+    "a Marian place of prayer and reflection",
+    "a golden statue of the Virgin Mary"
+  ]
+};
diff --git a/apps/widget/src/app/ModelAssessment.tsx b/apps/widget/src/app/ModelAssessment.tsx
index 493a0273af..7877bc504d 100644
--- a/apps/widget/src/app/ModelAssessment.tsx
+++ b/apps/widget/src/app/ModelAssessment.tsx
@@ -32,6 +32,7 @@ export class ModelAssessment extends React.Component {
     | "requestExp"
     | "requestObjectDetectionMetrics"
     | "requestPredictions"
+    | "requestQuestionAnsweringMetrics"
     | "requestDebugML"
     | "requestMatrix"
     | "requestImportances"
@@ -72,6 +73,19 @@ export class ModelAssessment extends React.Component {
     callBack.requestPredictions = async (data: any[]): Promise => {
       return callFlaskService(this.props.config, data, "/predict");
     };
+    callBack.requestQuestionAnsweringMetrics = async (
+      trueY: number[][][],
+      predictedY: number[][][],
+      aggregateMethod: string,
+      className: string,
+      iouThresh: number
+    ): Promise => {
+      return callFlaskService(
+        this.props.config,
+        [trueY, predictedY, aggregateMethod, className, iouThresh],
+        "/get_question_answering_metrics"
+      );
+    };
     callBack.requestMatrix = async (
       data: any[]
     ): Promise => {
diff --git a/libs/core-ui/src/index.ts b/libs/core-ui/src/index.ts
index 5a7238fb50..c67e37bb25 100644
--- a/libs/core-ui/src/index.ts
+++ b/libs/core-ui/src/index.ts
@@ -59,6 +59,7 @@ export * from "./lib/util/calculateConfusionMatrixData";
 export * from "./lib/util/calculateLineData";
 export * from "./lib/util/MultilabelStatisticsUtils";
 export * from "./lib/util/ObjectDetectionStatisticsUtils";
+export * from "./lib/util/QuestionAnsweringStatisticsUtils";
 export * from "./lib/util/StatisticsUtils";
 export * from "./lib/util/string";
 export * from "./lib/util/toScientific";
diff --git a/libs/core-ui/src/lib/Cohort/Cohort.ts b/libs/core-ui/src/lib/Cohort/Cohort.ts
index 339abe3e99..3ef7a217ae 100644
--- a/libs/core-ui/src/lib/Cohort/Cohort.ts
+++ b/libs/core-ui/src/lib/Cohort/Cohort.ts
@@ -18,7 +18,7 @@ export class Cohort {
   private readonly cohortIndex: number;
   private cachedAverageImportance: number[] | undefined;
   private cachedTransposedLocalFeatureImportances: number[][] | undefined;
- private currentSortKey: string | undefined; + private currentSortKey: undefined | string; private currentSortReversed = false; public constructor( @@ -96,7 +96,14 @@ export class Cohort { } public getRow(index: number): { [key: string]: number } { - return { ...this.jointDataset.dataDict?.[index] }; + const dataDict = this.jointDataset.dataDict?.[index]; + const convertedDataDict: { [key: string]: number } = {}; + if (dataDict) { + for (const key in dataDict) { + convertedDataDict[key] = Number(dataDict[key]); + } + } + return convertedDataDict; } public sort( @@ -202,7 +209,7 @@ export class Cohort { } private filterRow( - row: { [key: string]: number }, + row: { [key: string]: number | string }, filters: IFilter[] ): boolean { return filters @@ -221,9 +228,9 @@ export class Cohort { case FilterMethods.LessThanEqualTo: return rowVal <= filter.arg[0]; case FilterMethods.Includes: - return (filter.arg as number[]).includes(rowVal); + return (filter.arg as number[]).includes(Number(rowVal)); case FilterMethods.Excludes: - return !(filter.arg as number[]).includes(rowVal); + return !(filter.arg as number[]).includes(Number(rowVal)); case FilterMethods.InTheRangeOf: return rowVal >= filter.arg[0] && rowVal <= filter.arg[1]; default: diff --git a/libs/core-ui/src/lib/Cohort/ErrorCohort.ts b/libs/core-ui/src/lib/Cohort/ErrorCohort.ts index 9c79c2e53d..6fafe64a27 100644 --- a/libs/core-ui/src/lib/Cohort/ErrorCohort.ts +++ b/libs/core-ui/src/lib/Cohort/ErrorCohort.ts @@ -70,7 +70,7 @@ export class ErrorCohort { } private updateStatsFromData( - filteredData: Array<{ [key: string]: number }>, + filteredData: Array<{ [key: string]: string | number }>, jointDataset: JointDataset ): ErrorCohortStats { // Calculate various cohort and global stats diff --git a/libs/core-ui/src/lib/Context/ModelAssessmentContext.tsx b/libs/core-ui/src/lib/Context/ModelAssessmentContext.tsx index 23e1a9de90..5d20101edf 100644 --- a/libs/core-ui/src/lib/Context/ModelAssessmentContext.tsx +++ 
b/libs/core-ui/src/lib/Context/ModelAssessmentContext.tsx
@@ -145,6 +145,15 @@ export interface IModelAssessmentContext {
     abortSignal: AbortSignal
   ) => Promise) | undefined;
+  requestQuestionAnsweringMetrics?:
+    | ((
+        trueY: number[][][],
+        predictedY: number[][][],
+        aggregateMethod: string,
+        className: string,
+        iouThresh: number
+      ) => Promise)
+    | undefined;
   requestSplinePlotDistribution?: (
     request: any,
     abortSignal: AbortSignal
diff --git a/libs/core-ui/src/lib/Interfaces/IDataset.ts b/libs/core-ui/src/lib/Interfaces/IDataset.ts
index 30ae0d71e2..524ef5114b 100644
--- a/libs/core-ui/src/lib/Interfaces/IDataset.ts
+++ b/libs/core-ui/src/lib/Interfaces/IDataset.ts
@@ -12,13 +12,14 @@ export enum DatasetTaskType {
   MultilabelTextClassification = "multilabel_text_classification",
   MultilabelImageClassification = "multilabel_image_classification",
   Forecasting = "forecasting",
-  ObjectDetection = "object_detection"
+  ObjectDetection = "object_detection",
+  QuestionAnswering = "question_answering"
 }
 
 export interface IDataset {
   task_type: DatasetTaskType;
-  true_y: number[] | number[][];
-  predicted_y?: number[] | number[][];
+  true_y: number[] | number[][] | string[];
+  predicted_y?: number[] | number[][] | string[];
   probability_y?: number[][];
   features: unknown[][];
   feature_names: string[];
diff --git a/libs/core-ui/src/lib/Interfaces/IExplanationContext.ts b/libs/core-ui/src/lib/Interfaces/IExplanationContext.ts
index b04419fcbe..ba0ebfab3f 100644
--- a/libs/core-ui/src/lib/Interfaces/IExplanationContext.ts
+++ b/libs/core-ui/src/lib/Interfaces/IExplanationContext.ts
@@ -15,7 +15,8 @@ export enum ModelTypes {
   TextBinary = "textbinary",
   TextMulticlass = "textmulticlass",
   TextMultilabel =
"textmultilabel", - ObjectDetection = "objectdetection" + ObjectDetection = "objectdetection", + QuestionAnswering = "questionanswering" } export interface IExplanationContext { diff --git a/libs/core-ui/src/lib/Interfaces/IModelExplanationData.ts b/libs/core-ui/src/lib/Interfaces/IModelExplanationData.ts index dc6021f810..4ab86e1182 100644 --- a/libs/core-ui/src/lib/Interfaces/IModelExplanationData.ts +++ b/libs/core-ui/src/lib/Interfaces/IModelExplanationData.ts @@ -6,7 +6,7 @@ import { IPrecomputedExplanations } from "./ExplanationInterfaces"; export interface IModelExplanationData { modelClass?: ModelClass; method?: Method; - predictedY?: number[] | number[][]; + predictedY?: number[] | number[][] | string[]; probabilityY?: number[][]; explanationMethod?: string; precomputedExplanations?: IPrecomputedExplanations; diff --git a/libs/core-ui/src/lib/Interfaces/VisionExplanationInterfaces.ts b/libs/core-ui/src/lib/Interfaces/VisionExplanationInterfaces.ts index d42b37c85d..b10e9b45b4 100644 --- a/libs/core-ui/src/lib/Interfaces/VisionExplanationInterfaces.ts +++ b/libs/core-ui/src/lib/Interfaces/VisionExplanationInterfaces.ts @@ -4,8 +4,8 @@ import { DatasetTaskType } from "./IDataset"; export interface IVisionExplanationDashboardData { task_type: DatasetTaskType; - true_y: number[] | number[][]; - predicted_y: number[] | number[][]; + true_y: number[] | number[][] | string[]; + predicted_y: number[] | number[][] | string[]; features?: unknown[][]; feature_names?: string[]; class_names: string[]; diff --git a/libs/core-ui/src/lib/util/JointDataset.ts b/libs/core-ui/src/lib/util/JointDataset.ts index 35b76c4692..bd480af4f7 100644 --- a/libs/core-ui/src/lib/util/JointDataset.ts +++ b/libs/core-ui/src/lib/util/JointDataset.ts @@ -73,6 +73,7 @@ export class JointDataset { // these properties should only be accessed by Cohort class, // which enables independent filtered views of this data public dataDict: Array<{ [key: string]: number }> | undefined; + public 
strDataDict: Array<{ [key: string]: string }> | undefined;
   public binDict: { [key: string]: number[] | undefined } = {};
 
   private readonly _modelMeta: IExplanationModelMetadata;
@@ -666,7 +667,7 @@ export class JointDataset {
   }
 
   private updateMetaDataDict(
-    values: number[] | number[][],
+    values: number[] | number[][] | string[],
     metadata: IExplanationModelMetadata,
     labelColName: string,
     abbridgedLabel: string,
@@ -683,7 +684,11 @@ export class JointDataset {
           }
         });
       } else if (this.dataDict) {
-        this.dataDict[index][labelColName] = val;
+        if (typeof val !== "string") {
+          this.dataDict[index][labelColName] = val;
+        } else if (this.strDataDict) {
+          this.strDataDict[index][labelColName] = val;
+        }
       }
     });
     for (let i = 0; i < this.numLabels; i++) {
@@ -733,6 +738,13 @@ export class JointDataset {
     if (arr === undefined) {
       return;
     }
+    if (this.strDataDict === undefined) {
+      this.strDataDict = Array.from({ length: arr.length }).map((_, index) => {
+        const dict = {};
+        dict[JointDataset.IndexLabel] = index;
+        return dict;
+      });
+    }
     if (this.dataDict !== undefined) {
       if (this.dataDict.length !== arr.length) {
         throw new Error(
diff --git a/libs/core-ui/src/lib/util/JointDatasetUtils.ts b/libs/core-ui/src/lib/util/JointDatasetUtils.ts
index e59a4e2abd..6bfeda411a 100644
--- a/libs/core-ui/src/lib/util/JointDatasetUtils.ts
+++ b/libs/core-ui/src/lib/util/JointDatasetUtils.ts
@@ -14,9 +14,9 @@ import { AxisTypes } from "./IGenericChartProps";
 
 export interface IJointDatasetArgs {
   dataset?: any[][];
-  predictedY?: number[] | number[][];
+  predictedY?: number[] | number[][] | string[];
   predictedProbabilities?: number[][];
-  trueY?: number[] | number[][];
+  trueY?: number[] | number[][] | string[];
   localExplanations?:
     | IMultiClassLocalFeatureImportance
     | ISingleClassLocalFeatureImportance;
diff --git a/libs/core-ui/src/lib/util/MultilabelStatisticsUtils.ts b/libs/core-ui/src/lib/util/MultilabelStatisticsUtils.ts
index a9aed980a0..bcbbf64161
100644
--- a/libs/core-ui/src/lib/util/MultilabelStatisticsUtils.ts
+++ b/libs/core-ui/src/lib/util/MultilabelStatisticsUtils.ts
@@ -11,8 +11,13 @@ import {
 import { JointDataset } from "./JointDataset";
 
 export enum MultilabelMetrics {
+  BleuScore = "bleuScore",
   ExactMatchRatio = "exactMatchRatio",
-  HammingScore = "hammingScore"
+  HammingScore = "hammingScore",
+  F1Score = "f1Score",
+
+  MeteorScore = "meteorScore",
+  RougeScore = "rougeScore"
 }
 
 export const generateMultilabelStats: (
@@ -50,6 +55,10 @@ export const generateMultilabelStats: (
     hammingScore = hammingScore / numLabels;
     const sum = matchingLabels.reduce((prev, curr) => prev + curr, 0);
     const exactMatchRatio = sum / (numLabels * selectionArray.length);
+    const meteorScore = 0;
+    const f1Score = 0;
+    const rougeScore = 0;
+    const bleuScore = 0;
 
     return [
       {
@@ -62,6 +71,26 @@ export const generateMultilabelStats: (
         label: localization.Interpret.Statistics.exactMatchRatio,
         stat: exactMatchRatio
       },
+      {
+        key: MultilabelMetrics.MeteorScore,
+        label: localization.Interpret.Statistics.meteorScore,
+        stat: meteorScore
+      },
+      {
+        key: MultilabelMetrics.F1Score,
+        label: localization.Interpret.Statistics.f1Score,
+        stat: f1Score
+      },
+      {
+        key: MultilabelMetrics.BleuScore,
+        label: localization.Interpret.Statistics.bleuScore,
+        stat: bleuScore
+      },
+      {
+        key: MultilabelMetrics.RougeScore,
+        label: localization.Interpret.Statistics.rougeScore,
+        stat: rougeScore
+      },
       {
         key: MultilabelMetrics.HammingScore,
         label: localization.Interpret.Statistics.hammingScore,
diff --git a/libs/core-ui/src/lib/util/QuestionAnsweringStatisticsUtils.ts b/libs/core-ui/src/lib/util/QuestionAnsweringStatisticsUtils.ts
new file mode 100644
index 0000000000..a631534e48
--- /dev/null
+++ b/libs/core-ui/src/lib/util/QuestionAnsweringStatisticsUtils.ts
@@ -0,0 +1,106 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
+
+import { localization } from "@responsible-ai/localization";
+
+import {
+  ILabeledStatistic,
+  TotalCohortSamples
+} from "../Interfaces/IStatistic";
+
+import { JointDataset } from "./JointDataset";
+
+export enum QuestionAnsweringMetrics {
+  BleuScore = "bleuScore",
+  ExactMatchRatio = "exactMatchRatio",
+  F1Score = "f1Score",
+  MeteorScore = "meteorScore",
+  RougeScore = "rougeScore"
+}
+
+function getf1Score(actual: string[], predicted: string[]): number {
+  const truePositives = actual.filter((value) =>
+    predicted.includes(value)
+  ).length;
+  const falsePositives = predicted.filter(
+    (value) => !actual.includes(value)
+  ).length;
+  const falseNegatives = actual.filter(
+    (value) => !predicted.includes(value)
+  ).length;
+
+  const precision = truePositives / (truePositives + falsePositives);
+  const recall = truePositives / (truePositives + falseNegatives);
+
+  return 2 * ((precision * recall) / (precision + recall));
+}
+
+export const generateQuestionAnsweringStats: (
+  jointDataset: JointDataset,
+  selectionIndexes: number[][]
+) => ILabeledStatistic[][] = (
+  jointDataset: JointDataset,
+  selectionIndexes: number[][]
+): ILabeledStatistic[][] => {
+  return selectionIndexes.map((selectionArray) => {
+    const matchingLabels = [];
+    const count = selectionArray.length;
+    let trueYs: string[] = [];
+    let predYs: string[] = [];
+    if (jointDataset.strDataDict) {
+      trueYs = jointDataset.strDataDict.map(
+        (row) => row[JointDataset.TrueYLabel]
+      );
+      predYs = jointDataset.strDataDict.map(
+        (row) => row[JointDataset.PredictedYLabel]
+      );
+    }
+
+    const trueYSubset = selectionArray.map((i) => trueYs[i]);
+    const predYSubset = selectionArray.map((i) => predYs[i]);
+    matchingLabels.push(
+      trueYSubset.filter((trueY, index) => trueY === predYSubset[index]).length
+    );
+
+    const meteorScore = 0;
+    const rougeScore = 0;
+    const bleuScore = 0;
+    const sum = matchingLabels.reduce((prev, curr) => prev + curr, 0);
+    const exactMatchRatio = sum / selectionArray.length;
+
+    const f1Score = getf1Score(trueYSubset, predYSubset);
+
+    return [
+      {
+        key: TotalCohortSamples,
+        label: localization.Interpret.Statistics.samples,
+        stat: count
+      },
+      {
+        key: QuestionAnsweringMetrics.ExactMatchRatio,
+        label: localization.Interpret.Statistics.exactMatchRatio,
+        stat: exactMatchRatio
+      },
+      {
+        key: QuestionAnsweringMetrics.F1Score,
+        label: localization.Interpret.Statistics.f1Score,
+        stat: f1Score
+      },
+      {
+        key: QuestionAnsweringMetrics.MeteorScore,
+        label: localization.Interpret.Statistics.meteorScore,
+        stat: meteorScore
+      },
+      {
+        key: QuestionAnsweringMetrics.BleuScore,
+        label: localization.Interpret.Statistics.bleuScore,
+        stat: bleuScore
+      },
+      {
+        key: QuestionAnsweringMetrics.RougeScore,
+        label: localization.Interpret.Statistics.rougeScore,
+        stat: rougeScore
+      }
+    ];
+  });
+};
diff --git a/libs/core-ui/src/lib/util/StatisticsUtils.ts b/libs/core-ui/src/lib/util/StatisticsUtils.ts
index af331d2308..3cc3ad6864 100644
--- a/libs/core-ui/src/lib/util/StatisticsUtils.ts
+++ b/libs/core-ui/src/lib/util/StatisticsUtils.ts
@@ -21,6 +21,7 @@ import {
 } from "./JointDatasetUtils";
 import { generateMultilabelStats } from "./MultilabelStatisticsUtils";
 import { generateObjectDetectionStats } from "./ObjectDetectionStatisticsUtils";
+import { generateQuestionAnsweringStats } from "./QuestionAnsweringStatisticsUtils";
 
 export enum BinaryClassificationMetrics {
   Accuracy = "accuracy",
@@ -260,6 +261,9 @@ export const generateMetrics: (
   ) {
     return generateMultilabelStats(jointDataset, selectionIndexes);
   }
+  if (modelType === ModelTypes.QuestionAnswering) {
+    return generateQuestionAnsweringStats(jointDataset, selectionIndexes);
+  }
   const trueYs = jointDataset.unwrap(JointDataset.TrueYLabel);
   const predYs =
jointDataset.unwrap(JointDataset.PredictedYLabel); if (modelType === ModelTypes.Regression) { diff --git a/libs/core-ui/src/lib/util/datasetUtils/getColumnRanges.ts b/libs/core-ui/src/lib/util/datasetUtils/getColumnRanges.ts index 5f7d959393..6fdbc70289 100644 --- a/libs/core-ui/src/lib/util/datasetUtils/getColumnRanges.ts +++ b/libs/core-ui/src/lib/util/datasetUtils/getColumnRanges.ts @@ -134,7 +134,11 @@ function getRegressionErrorFeatureRange( if (Array.isArray(trueY) || Array.isArray(predictedY)) { return; } - regressionErrors.push(Math.abs(trueY - predictedY)); + if (typeof trueY !== "string" && typeof predictedY !== "string") { + regressionErrors.push(Math.abs(trueY - predictedY)); + } else { + regressionErrors.push(0); + } } return { max: _.max(regressionErrors) || 0, @@ -151,7 +155,7 @@ function getRegressionErrorFeatureRange( function getRange( dataset: IDataset, modelType: ModelTypes, - values: number[] | number[][], + values: number[] | number[][] | string[], property: string, ranges: { [key: string]: IColumnRange } ): void { @@ -163,7 +167,11 @@ function getRange( // this for loop is only to let it make sure values is a 1D array, so it can be used with _.max and _.min values.forEach((value) => { if (!Array.isArray(value)) { - numbers.push(value); + if (typeof value !== "string") { + numbers.push(value); + } else { + numbers.push(0); + } } }); ranges[property] = { diff --git a/libs/counterfactuals/src/util/generatePlotlyProps.ts b/libs/counterfactuals/src/util/generatePlotlyProps.ts index 71c3090d44..cebce468c0 100644 --- a/libs/counterfactuals/src/util/generatePlotlyProps.ts +++ b/libs/counterfactuals/src/util/generatePlotlyProps.ts @@ -106,7 +106,7 @@ export function generatePlotlyProps( } function generateDataTrace( - dictionary: Array<{ [key: string]: number }>, + dictionary: Array<{ [key: string]: string | number }>, chartProps: IGenericChartProps, trace: IData, jointDataset: JointDataset diff --git 
a/libs/dataset-explorer/src/lib/TableView/TableView.tsx b/libs/dataset-explorer/src/lib/TableView/TableView.tsx index db0504a250..13b8f97851 100644 --- a/libs/dataset-explorer/src/lib/TableView/TableView.tsx +++ b/libs/dataset-explorer/src/lib/TableView/TableView.tsx @@ -265,9 +265,11 @@ export class TableView extends React.Component< filteredDataRows = this.props.selectedCohort.cohort.filteredData; const numRows: number = filteredDataRows.length; - const indices = filteredDataRows.map((row: { [key: string]: number }) => { - return row[JointDataset.IndexLabel] as number; - }); + const indices = filteredDataRows.map( + (row: { [key: string]: string | number }) => { + return row[JointDataset.IndexLabel] as number; + } + ); const rows = constructRows( filteredDataRows, diff --git a/libs/error-analysis/src/lib/ErrorAnalysisDashboard/Controls/TabularDataView/TabularDataView.tsx b/libs/error-analysis/src/lib/ErrorAnalysisDashboard/Controls/TabularDataView/TabularDataView.tsx index 6300f5abf7..97d608b2cc 100644 --- a/libs/error-analysis/src/lib/ErrorAnalysisDashboard/Controls/TabularDataView/TabularDataView.tsx +++ b/libs/error-analysis/src/lib/ErrorAnalysisDashboard/Controls/TabularDataView/TabularDataView.tsx @@ -213,7 +213,9 @@ export class TabularDataView extends React.Component< this.props.setWhatIfDatapoint?.(item[0] as number); }; - private tabularDataFilter = (row: { [key: string]: number }): boolean => { + private tabularDataFilter = (row: { + [key: string]: string | number; + }): boolean => { switch (this.props.dataView) { case DataViewKeys.CorrectInstances: { if ( diff --git a/libs/interpret-vision/src/lib/VisionExplanationDashboard/Interfaces/IVisionExplanationDashboardProps.ts b/libs/interpret-vision/src/lib/VisionExplanationDashboard/Interfaces/IVisionExplanationDashboardProps.ts index d20dece04d..ece235e7cd 100644 --- a/libs/interpret-vision/src/lib/VisionExplanationDashboard/Interfaces/IVisionExplanationDashboardProps.ts +++ 
b/libs/interpret-vision/src/lib/VisionExplanationDashboard/Interfaces/IVisionExplanationDashboardProps.ts
@@ -23,6 +23,13 @@ export interface IVisionExplanationDashboardProps {
     iouThresh: number,
     abortSignal: AbortSignal
   ) => Promise;
+  requestQuestionAnsweringMetrics?: (
+    trueY: number[][][],
+    predictedY: number[][][],
+    aggregateMethod: string,
+    className: string,
+    iouThresh: number
+  ) => Promise;
   selectedCohort: ErrorCohort;
   setSelectedCohort: (cohort: ErrorCohort) => void;
 }
diff --git a/libs/interpret-vision/src/lib/VisionExplanationDashboard/VisionExplanationDashboardHelper.ts b/libs/interpret-vision/src/lib/VisionExplanationDashboard/VisionExplanationDashboardHelper.ts
index cd66dc31ec..828775d2af 100644
--- a/libs/interpret-vision/src/lib/VisionExplanationDashboard/VisionExplanationDashboardHelper.ts
+++ b/libs/interpret-vision/src/lib/VisionExplanationDashboard/VisionExplanationDashboardHelper.ts
@@ -35,7 +35,7 @@ export const defaultImageSizes = {
 };
 
 export function mapClassNames(
-  labels: number[] | number[][],
+  labels: number[] | number[][] | string[],
   classNames: string[]
 ): string[] | string[][] {
   if (Array.isArray(labels[0])) {
@@ -134,8 +134,8 @@ export function getItems(
 > {
   const indices = new Set(
     props.selectedCohort.cohort.filteredData.map(
-      (row: { [key: string]: number }) => {
-        return row[JointDataset.IndexLabel] as number;
+      (row: { [key: string]: string | number }) => {
+        return row[JointDataset.IndexLabel] as string | number;
       }
     )
   );
diff --git a/libs/interpret/src/lib/MLIDashboard/Controls/WhatIfTab/ExistingPredictionLabels.tsx b/libs/interpret/src/lib/MLIDashboard/Controls/WhatIfTab/ExistingPredictionLabels.tsx
index c2bc5f05d6..bd744a3bd3 100644
--- a/libs/interpret/src/lib/MLIDashboard/Controls/WhatIfTab/ExistingPredictionLabels.tsx
+++ 
b/libs/interpret/src/lib/MLIDashboard/Controls/WhatIfTab/ExistingPredictionLabels.tsx
@@ -45,8 +45,8 @@ export class ExistingPredictionLabels extends React.Component
-    dictionary: Array<{ [key: string]: number }>,
+    dictionary: Array<{ [key: string]: string | number }>,
     chartProps: IGenericChartProps,
     trace: IData
   ): void {
diff --git a/libs/localization/src/lib/en.json b/libs/localization/src/lib/en.json
index 0172e8accc..6fa0a0d7a4 100644
--- a/libs/localization/src/lib/en.json
+++ b/libs/localization/src/lib/en.json
@@ -1173,11 +1173,14 @@
     "_rSquared.comment": "the coefficient of determination, see https://en.wikipedia.org/wiki/Coefficient_of_determination",
     "_recall.comment": "computed recall of model, see https://en.wikipedia.org/wiki/Evaluation_of_binary_classifiers",
     "accuracy": "Accuracy: {0}",
+    "bleuScore": "Bleu score: {0}",
     "exactMatchRatio": "Exact match ratio: {0}",
+    "rougeScore": "Rouge Score: {0}",
     "fnr": "False negative rate: {0}",
     "fpr": "False positive rate: {0}",
     "hammingScore": "Hamming score: {0}",
     "meanPrediction": "Mean prediction {0}",
+    "meteorScore": "Meteor Score: {0}",
     "mse": "Mean squared error: {0}",
     "precision": "Precision: {0}",
     "rSquared": "R²: {0}",
@@ -1688,6 +1691,18 @@
       "name": "Exact match ratio",
       "description": "The ratio of instances classified correctly for every label in multilabel task."
     },
+    "meteorScore": {
+      "name": "Meteor Score",
+      "description": "Text generation metric based on unigram precision and recall, with stemming and synonym matching against the reference answer."
+    },
+    "bleuScore": {
+      "name": "Bleu Score",
+      "description": "Measures the n-gram overlap between the generated text and the reference answer."
+    },
+    "rougeScore": {
+      "name": "Rouge Score",
+      "description": "Measures the overlap of n-grams and longest common subsequences between the generated text and the reference answer."
+    },
     "hammingScore": {
       "name": "Hamming score",
       "description": "The average ratio of labels classified correctly among those classified as 1 in multilabel task."
diff --git a/libs/model-assessment/src/lib/ModelAssessmentDashboard/Context/buildModelAssessmentContext.ts b/libs/model-assessment/src/lib/ModelAssessmentDashboard/Context/buildModelAssessmentContext.ts index 7737468898..de69f9ff89 100644 --- a/libs/model-assessment/src/lib/ModelAssessmentDashboard/Context/buildModelAssessmentContext.ts +++ b/libs/model-assessment/src/lib/ModelAssessmentDashboard/Context/buildModelAssessmentContext.ts @@ -39,7 +39,7 @@ export function buildInitialModelAssessmentContext( props: IModelAssessmentDashboardProps ): IModelAssessmentDashboardState { const modelMetadata = buildModelMetadata(props); - const modelType = getModelTypeFromProps(props); + const modelType = getModelTypeFromProps(props, modelMetadata.classNames); const columnRanges = getColumnRanges(props.dataset, modelType); let localExplanations: @@ -157,7 +157,7 @@ function buildModelMetadata( props: IModelAssessmentDashboardProps ): IExplanationModelMetadata { let classNames = props.dataset.class_names; - const modelType = getModelTypeFromProps(props); + const modelType = getModelTypeFromProps(props, classNames); let featureNames = props.dataset.feature_names; let featureNamesAbridged: string[]; const maxLength = 18; diff --git a/libs/model-assessment/src/lib/ModelAssessmentDashboard/Controls/ModelOverview/ModelOverview.tsx b/libs/model-assessment/src/lib/ModelAssessmentDashboard/Controls/ModelOverview/ModelOverview.tsx index 5fcdb8c022..59695bf588 100644 --- a/libs/model-assessment/src/lib/ModelAssessmentDashboard/Controls/ModelOverview/ModelOverview.tsx +++ b/libs/model-assessment/src/lib/ModelAssessmentDashboard/Controls/ModelOverview/ModelOverview.tsx @@ -34,6 +34,7 @@ import { TelemetryEventName, DatasetTaskType, ImageClassificationMetrics, + QuestionAnsweringMetrics, TotalCohortSamples } from "@responsible-ai/core-ui"; import { localization } from "@responsible-ai/localization"; @@ -152,7 +153,11 @@ export class ModelOverview extends React.Component< ) { 
defaultSelectedMetrics = [ MultilabelMetrics.ExactMatchRatio, - MultilabelMetrics.HammingScore + MultilabelMetrics.HammingScore, + MultilabelMetrics.MeteorScore, + MultilabelMetrics.BleuScore, + MultilabelMetrics.F1Score, + MultilabelMetrics.RougeScore ]; } else if ( this.context.dataset.task_type === DatasetTaskType.ObjectDetection @@ -162,6 +167,13 @@ export class ModelOverview extends React.Component< ObjectDetectionMetrics.AveragePrecision, ObjectDetectionMetrics.AverageRecall ]; + } else if ( + this.context.dataset.task_type === DatasetTaskType.QuestionAnswering + ) { + defaultSelectedMetrics = [ + QuestionAnsweringMetrics.ExactMatchRatio, + QuestionAnsweringMetrics.F1Score + ]; } else { // task_type === "regression" defaultSelectedMetrics = [ diff --git a/libs/model-assessment/src/lib/ModelAssessmentDashboard/Controls/ModelOverview/StatsTableUtils.ts b/libs/model-assessment/src/lib/ModelAssessmentDashboard/Controls/ModelOverview/StatsTableUtils.ts index ab191bcd52..6fbcc62e80 100644 --- a/libs/model-assessment/src/lib/ModelAssessmentDashboard/Controls/ModelOverview/StatsTableUtils.ts +++ b/libs/model-assessment/src/lib/ModelAssessmentDashboard/Controls/ModelOverview/StatsTableUtils.ts @@ -397,6 +397,35 @@ export function getSelectableMetrics( text: localization.ModelAssessment.ModelOverview.metrics.exactMatchRatio .name }, + { + description: + localization.ModelAssessment.ModelOverview.metrics.meteorScore + .description, + key: MultilabelMetrics.MeteorScore, + text: localization.ModelAssessment.ModelOverview.metrics.meteorScore + .name + }, + { + description: + localization.ModelAssessment.ModelOverview.metrics.f1Score + .description, + key: MultilabelMetrics.F1Score, + text: localization.ModelAssessment.ModelOverview.metrics.f1Score.name + }, + { + description: + localization.ModelAssessment.ModelOverview.metrics.bleuScore + .description, + key: MultilabelMetrics.BleuScore, + text: localization.ModelAssessment.ModelOverview.metrics.bleuScore.name + }, + { + 
description: + localization.ModelAssessment.ModelOverview.metrics.rougeScore + .description, + key: MultilabelMetrics.RougeScore, + text: localization.ModelAssessment.ModelOverview.metrics.rougeScore.name + }, { description: localization.ModelAssessment.ModelOverview.metrics.hammingScore @@ -433,6 +462,46 @@ export function getSelectableMetrics( .name } ); + } else if (taskType === DatasetTaskType.QuestionAnswering) { + selectableMetrics.push( + { + description: + localization.ModelAssessment.ModelOverview.metrics.exactMatchRatio + .description, + key: MultilabelMetrics.ExactMatchRatio, + text: localization.ModelAssessment.ModelOverview.metrics.exactMatchRatio + .name + }, + { + description: + localization.ModelAssessment.ModelOverview.metrics.meteorScore + .description, + key: MultilabelMetrics.MeteorScore, + text: localization.ModelAssessment.ModelOverview.metrics.meteorScore + .name + }, + { + description: + localization.ModelAssessment.ModelOverview.metrics.f1Score + .description, + key: MultilabelMetrics.F1Score, + text: localization.ModelAssessment.ModelOverview.metrics.f1Score.name + }, + { + description: + localization.ModelAssessment.ModelOverview.metrics.bleuScore + .description, + key: MultilabelMetrics.BleuScore, + text: localization.ModelAssessment.ModelOverview.metrics.bleuScore.name + }, + { + description: + localization.ModelAssessment.ModelOverview.metrics.rougeScore + .description, + key: MultilabelMetrics.RougeScore, + text: localization.ModelAssessment.ModelOverview.metrics.rougeScore.name + } + ); } else { // task_type === "regression" selectableMetrics.push( diff --git a/libs/model-assessment/src/lib/ModelAssessmentDashboard/Controls/TabsView/TabsView.tsx b/libs/model-assessment/src/lib/ModelAssessmentDashboard/Controls/TabsView/TabsView.tsx index ef86c79297..aa3454f52b 100644 --- a/libs/model-assessment/src/lib/ModelAssessmentDashboard/Controls/TabsView/TabsView.tsx +++ 
b/libs/model-assessment/src/lib/ModelAssessmentDashboard/Controls/TabsView/TabsView.tsx @@ -182,6 +182,9 @@ export class TabsView extends React.PureComponent< requestObjectDetectionMetrics={ this.props.requestObjectDetectionMetrics } + requestQuestionAnsweringMetrics={ + this.props.requestQuestionAnsweringMetrics + } cohorts={this.props.cohorts} setSelectedCohort={this.props.setSelectedCohort} selectedCohort={this.props.selectedCohort} diff --git a/libs/model-assessment/src/lib/ModelAssessmentDashboard/Controls/TabsView/TabsViewProps.ts b/libs/model-assessment/src/lib/ModelAssessmentDashboard/Controls/TabsView/TabsViewProps.ts index 6372f982a0..0fc92acf09 100644 --- a/libs/model-assessment/src/lib/ModelAssessmentDashboard/Controls/TabsView/TabsViewProps.ts +++ b/libs/model-assessment/src/lib/ModelAssessmentDashboard/Controls/TabsView/TabsViewProps.ts @@ -51,10 +51,18 @@ export interface ITabsViewProps { iouThresh: number, abortSignal: AbortSignal ) => Promise; + requestPredictions?: ( request: any[], abortSignal: AbortSignal ) => Promise; + requestQuestionAnsweringMetrics?: ( + trueY: number[][][], + predictedY: number[][][], + aggregateMethod: string, + className: string, + iouThresh: number + ) => Promise; requestDebugML?: ( request: any[], abortSignal: AbortSignal diff --git a/libs/model-assessment/src/lib/ModelAssessmentDashboard/ModelAssessmentDashboard.tsx b/libs/model-assessment/src/lib/ModelAssessmentDashboard/ModelAssessmentDashboard.tsx index a5e3c0934d..716bdbef11 100644 --- a/libs/model-assessment/src/lib/ModelAssessmentDashboard/ModelAssessmentDashboard.tsx +++ b/libs/model-assessment/src/lib/ModelAssessmentDashboard/ModelAssessmentDashboard.tsx @@ -98,6 +98,8 @@ export class ModelAssessmentDashboard extends CohortBasedComponent< requestObjectDetectionMetrics: this.props.requestObjectDetectionMetrics, requestPredictions: this.props.requestPredictions, + requestQuestionAnsweringMetrics: + this.props.requestQuestionAnsweringMetrics, 
      requestSplinePlotDistribution:
        this.props.requestSplinePlotDistribution,
      requestTestDataRow: this.props.requestTestDataRow,
@@ -137,6 +139,9 @@ export class ModelAssessmentDashboard extends CohortBasedComponent<
            this.props.requestObjectDetectionMetrics
          }
          requestPredictions={this.props.requestPredictions}
+          requestQuestionAnsweringMetrics={
+            this.props.requestQuestionAnsweringMetrics
+          }
          requestDebugML={this.props.requestDebugML}
          requestImportances={this.props.requestImportances}
          requestMatrix={this.props.requestMatrix}
diff --git a/libs/model-assessment/src/lib/ModelAssessmentDashboard/ModelAssessmentDashboardProps.ts b/libs/model-assessment/src/lib/ModelAssessmentDashboard/ModelAssessmentDashboardProps.ts
index 68c1432b2f..304b8c46f5 100644
--- a/libs/model-assessment/src/lib/ModelAssessmentDashboard/ModelAssessmentDashboardProps.ts
+++ b/libs/model-assessment/src/lib/ModelAssessmentDashboard/ModelAssessmentDashboardProps.ts
@@ -122,6 +122,13 @@ export interface IModelAssessmentDashboardProps
    iouThresh: number,
    abortSignal: AbortSignal
  ) => Promise;
+  requestQuestionAnsweringMetrics?: (
+    trueY: number[][][],
+    predictedY: number[][][],
+    aggregateMethod: string,
+    className: string,
+    iouThresh: number
+  ) => Promise;
  requestBubblePlotData?: (
    filter: unknown[],
    compositeFilter: unknown[],
diff --git a/libs/model-assessment/src/lib/ModelAssessmentDashboard/utils/getModelTypeFromProps.ts b/libs/model-assessment/src/lib/ModelAssessmentDashboard/utils/getModelTypeFromProps.ts
index cb18385e79..929e86e9e1 100644
--- a/libs/model-assessment/src/lib/ModelAssessmentDashboard/utils/getModelTypeFromProps.ts
+++ 
b/libs/model-assessment/src/lib/ModelAssessmentDashboard/utils/getModelTypeFromProps.ts
@@ -50,5 +50,8 @@ export function getModelTypeFromProps(
   if (taskType === DatasetTaskType.ObjectDetection) {
     return ModelTypes.ObjectDetection;
   }
+  if (taskType === DatasetTaskType.QuestionAnswering) {
+    return ModelTypes.QuestionAnswering;
+  }
   return modelType;
 }
diff --git a/raiwidgets/raiwidgets/responsibleai_dashboard.py b/raiwidgets/raiwidgets/responsibleai_dashboard.py
index bdad46826f..520a3a4e90 100644
--- a/raiwidgets/raiwidgets/responsibleai_dashboard.py
+++ b/raiwidgets/raiwidgets/responsibleai_dashboard.py
@@ -92,3 +92,12 @@ def get_object_detection_metrics():
             '/get_object_detection_metrics',
             methods=["POST"]
         )
+
+        def get_question_answering_metrics():
+            data = request.get_json(force=True)
+            return jsonify(self.input.get_question_answering_metrics(data))
+        self.add_url_rule(
+            get_question_answering_metrics,
+            '/get_question_answering_metrics',
+            methods=["POST"]
+        )
diff --git a/raiwidgets/raiwidgets/responsibleai_dashboard_input.py b/raiwidgets/raiwidgets/responsibleai_dashboard_input.py
index d4bd148dac..d64df5ad4a 100644
--- a/raiwidgets/raiwidgets/responsibleai_dashboard_input.py
+++ b/raiwidgets/raiwidgets/responsibleai_dashboard_input.py
@@ -353,3 +353,41 @@ def get_object_detection_metrics(self, post_data):
                     "inner error: {}".format(e_str),
                 WidgetRequestResponseConstants.data: []
             }
+
+    def get_question_answering_metrics(self, post_data):
+        """Flask endpoint function to get Model Overview metrics
+        for the Question Answering scenario.
+
+        :param post_data: List of inputs in the order
+            [true_y, predicted_y, aggregate_method, class_name, iou_thresh].
+        :type post_data: List
+
+        :return: JSON/dict data response
+        :rtype: Dict[str, List]
+        """
+        try:
+            true_y = post_data[0]
+            predicted_y = post_data[1]
+            aggregate_method = post_data[2]
+            class_name = post_data[3]
+            iou_thresh = post_data[4]
+            exp = self._analysis.compute_question_answering_metrics(
+                true_y,
+                predicted_y,
+                aggregate_method,
+                class_name,
+                iou_thresh
+            )
+            return {
+                WidgetRequestResponseConstants.data: exp
+            }
+        except Exception as e:
+            print(e)
+            traceback.print_exc()
+            e_str = _format_exception(e)
+            return {
+                WidgetRequestResponseConstants.error:
+                    "Failed to get Question Answering Model Overview metrics,"
+                    "inner error: {}".format(e_str),
+                WidgetRequestResponseConstants.data: []
+            }