Skip to content

Commit

Permalink
adding question answering RAI text dashboard metrics with umass team
Browse files Browse the repository at this point in the history
  • Loading branch information
imatiach-msft authored and mehektulsyan committed Apr 24, 2023
1 parent 0af6cef commit 9c9adb2
Show file tree
Hide file tree
Showing 35 changed files with 568 additions and 36 deletions.
5 changes: 5 additions & 0 deletions apps/dashboard/src/app/textApplications.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import {
emotion,
emotionModelExplanationData
} from "../model-assessment-text/__mock_data__/emotion";
import { squad } from "../model-assessment-text/__mock_data__/squad";

import {
IDataSet,
Expand Down Expand Up @@ -56,6 +57,10 @@ export const textApplications: ITextApplications = <const>{
classDimension: 3,
dataset: emotion,
modelExplanationData: [emotionModelExplanationData]
} as IModelAssessmentDataSet,
squad: {
classDimension: 3,
dataset: squad
} as IModelAssessmentDataSet
},
versions: { "1": 1, "2:Static-View": 2 }
Expand Down
46 changes: 46 additions & 0 deletions apps/dashboard/src/model-assessment-text/__mock_data__/squad.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

import { DatasetTaskType, IDataset } from "@responsible-ai/core-ui";

export const squad: IDataset = {
categorical_features: [],
class_names: undefined,
feature_names: [
"context_positive_words",
"context_negative_words",
"context_negation_words",
"context_negated_entities",
"context_named_persons",
"context_sentence_length",
"question_positive_words",
"question_negative_words",
"question_negation_words",
"question_negated_entities",
"question_named_persons",
"question_sentence_length"
],
features: [
[42, 0, 0, 0, 2, 695, 3, 0, 0, 0, 0, 71],
[42, 0, 0, 0, 2, 695, 3, 0, 0, 0, 0, 49],
[42, 0, 0, 0, 2, 695, 5, 0, 0, 0, 1, 76],
[42, 0, 0, 0, 2, 695, 1, 0, 0, 0, 1, 33],
[42, 0, 0, 0, 2, 695, 4, 0, 0, 0, 0, 52]
],
predicted_y: [
"Saint Bernadette Soubirous",
"a copper statue of Christ",
"the Main Building",
"a Marian place of prayer and reflection",
"a golden statue of the Virgin Mary"
],
target_column: "answers",
task_type: DatasetTaskType.QuestionAnswering,
true_y: [
"Saint Bernadette Soubirous",
"a copper statue of Christ",
"the Main Building",
"a Marian place of prayer and reflection",
"a golden statue of the Virgin Mary"
]
};
37 changes: 37 additions & 0 deletions apps/widget/src/app/ModelAssessment.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,11 @@ export class ModelAssessment extends React.Component<IModelAssessmentProps> {
| "requestExp"
| "requestObjectDetectionMetrics"
| "requestPredictions"
| "requestQuestionAnsweringMetrics"
<<<<<<< HEAD
| "requestQuestionAnsweringMetrics"
=======
>>>>>>> 8bdf8400 (python scripts placeholder for QA metrics)
| "requestDebugML"
| "requestMatrix"
| "requestImportances"
Expand Down Expand Up @@ -72,6 +77,38 @@ export class ModelAssessment extends React.Component<IModelAssessmentProps> {
callBack.requestPredictions = async (data: any[]): Promise<any[]> => {
return callFlaskService(this.props.config, data, "/predict");
};
callBack.requestQuestionAnsweringMetrics = async (
trueY: number[][][],
predictedY: number[][][],
aggregateMethod: string,
className: string,
iouThresh: number
): Promise<any[]> => {
return callFlaskService(
this.props.config,
[trueY, predictedY, aggregateMethod, className, iouThresh],
"/get_question_answering_metrics"
);
};
<<<<<<< HEAD
callBack.requestPredictions = async (data: any[]): Promise<any[]> => {
return callFlaskService(this.props.config, data, "/predict");
};
callBack.requestQuestionAnsweringMetrics = async (
trueY: number[][][],
predictedY: number[][][],
aggregateMethod: string,
className: string,
iouThresh: number
): Promise<any[]> => {
return callFlaskService(
this.props.config,
[trueY, predictedY, aggregateMethod, className, iouThresh],
"/get_question_answering_metrics"
);
};
=======
>>>>>>> 8bdf8400 (python scripts placeholder for QA metrics)
callBack.requestMatrix = async (
data: any[]
): Promise<IErrorAnalysisMatrix> => {
Expand Down
1 change: 1 addition & 0 deletions libs/core-ui/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ export * from "./lib/util/calculateConfusionMatrixData";
export * from "./lib/util/calculateLineData";
export * from "./lib/util/MultilabelStatisticsUtils";
export * from "./lib/util/ObjectDetectionStatisticsUtils";
export * from "./lib/util/QuestionAnsweringStatisticsUtils";
export * from "./lib/util/StatisticsUtils";
export * from "./lib/util/string";
export * from "./lib/util/toScientific";
Expand Down
17 changes: 12 additions & 5 deletions libs/core-ui/src/lib/Cohort/Cohort.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ export class Cohort {
private readonly cohortIndex: number;
private cachedAverageImportance: number[] | undefined;
private cachedTransposedLocalFeatureImportances: number[][] | undefined;
private currentSortKey: string | undefined;
private currentSortKey: undefined | string;
private currentSortReversed = false;

public constructor(
Expand Down Expand Up @@ -96,7 +96,14 @@ export class Cohort {
}

public getRow(index: number): { [key: string]: number } {
return { ...this.jointDataset.dataDict?.[index] };
const dataDict = this.jointDataset.dataDict?.[index];
const convertedDataDict: { [key: string]: number } = {};
if (dataDict) {
for (const key in dataDict) {
convertedDataDict[key] = Number(dataDict[key]);
}
}
return convertedDataDict;
}

public sort(
Expand Down Expand Up @@ -202,7 +209,7 @@ export class Cohort {
}

private filterRow(
row: { [key: string]: number },
row: { [key: string]: number | string },
filters: IFilter[]
): boolean {
return filters
Expand All @@ -221,9 +228,9 @@ export class Cohort {
case FilterMethods.LessThanEqualTo:
return rowVal <= filter.arg[0];
case FilterMethods.Includes:
return (filter.arg as number[]).includes(rowVal);
return (filter.arg as number[]).includes(Number(rowVal));
case FilterMethods.Excludes:
return !(filter.arg as number[]).includes(rowVal);
return !(filter.arg as number[]).includes(Number(rowVal));
case FilterMethods.InTheRangeOf:
return rowVal >= filter.arg[0] && rowVal <= filter.arg[1];
default:
Expand Down
2 changes: 1 addition & 1 deletion libs/core-ui/src/lib/Cohort/ErrorCohort.ts
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ export class ErrorCohort {
}

private updateStatsFromData(
filteredData: Array<{ [key: string]: number }>,
filteredData: Array<{ [key: string]: string | number }>,
jointDataset: JointDataset
): ErrorCohortStats {
// Calculate various cohort and global stats
Expand Down
21 changes: 21 additions & 0 deletions libs/core-ui/src/lib/Context/ModelAssessmentContext.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,27 @@ export interface IModelAssessmentContext {
abortSignal: AbortSignal
) => Promise<any[]>)
| undefined;
requestQuestionAnsweringMetrics?:
| ((
trueY: number[][][],
predictedY: number[][][],
aggregateMethod: string,
className: string,
iouThresh: number
) => Promise<any[]>)
| undefined;
<<<<<<< HEAD
requestQuestionAnsweringMetrics?:
| ((
trueY: number[][][],
predictedY: number[][][],
aggregateMethod: string,
className: string,
iouThresh: number
) => Promise<any[]>)
| undefined;
=======
>>>>>>> 8bdf8400 (python scripts placeholder for QA metrics)
requestSplinePlotDistribution?: (
request: any,
abortSignal: AbortSignal
Expand Down
7 changes: 4 additions & 3 deletions libs/core-ui/src/lib/Interfaces/IDataset.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,14 @@ export enum DatasetTaskType {
MultilabelTextClassification = "multilabel_text_classification",
MultilabelImageClassification = "multilabel_image_classification",
Forecasting = "forecasting",
ObjectDetection = "object_detection"
ObjectDetection = "object_detection",
QuestionAnswering = "question_answering"
}

export interface IDataset {
task_type: DatasetTaskType;
true_y: number[] | number[][];
predicted_y?: number[] | number[][];
true_y: number[] | number[][] | string[];
predicted_y?: number[] | number[][] | string[];
probability_y?: number[][];
features: unknown[][];
feature_names: string[];
Expand Down
3 changes: 2 additions & 1 deletion libs/core-ui/src/lib/Interfaces/IExplanationContext.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ export enum ModelTypes {
TextBinary = "textbinary",
TextMulticlass = "textmulticlass",
TextMultilabel = "textmultilabel",
ObjectDetection = "objectdetection"
ObjectDetection = "objectdetection",
QuestionAnswering = "questionanswering"
}

export interface IExplanationContext {
Expand Down
2 changes: 1 addition & 1 deletion libs/core-ui/src/lib/Interfaces/IModelExplanationData.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import { IPrecomputedExplanations } from "./ExplanationInterfaces";
export interface IModelExplanationData {
modelClass?: ModelClass;
method?: Method;
predictedY?: number[] | number[][];
predictedY?: number[] | number[][] | string[];
probabilityY?: number[][];
explanationMethod?: string;
precomputedExplanations?: IPrecomputedExplanations;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
import { DatasetTaskType } from "./IDataset";
export interface IVisionExplanationDashboardData {
task_type: DatasetTaskType;
true_y: number[] | number[][];
predicted_y: number[] | number[][];
true_y: number[] | number[][] | string[];
predicted_y: number[] | number[][] | string[];
features?: unknown[][];
feature_names?: string[];
class_names: string[];
Expand Down
17 changes: 15 additions & 2 deletions libs/core-ui/src/lib/util/JointDataset.ts
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ export class JointDataset {
// these properties should only be accessed by Cohort class,
// which enables independent filtered views of this data
public dataDict: Array<{ [key: string]: number }> | undefined;
public strDataDict: Array<{ [key: string]: string }> | undefined;
public binDict: { [key: string]: number[] | undefined } = {};

private readonly _modelMeta: IExplanationModelMetadata;
Expand Down Expand Up @@ -666,7 +667,8 @@ export class JointDataset {
}

private updateMetaDataDict(
values: number[] | number[][],
values: number[] | number[][] | string[],
values: number[] | number[][] | string[],
metadata: IExplanationModelMetadata,
labelColName: string,
abbridgedLabel: string,
Expand All @@ -683,7 +685,11 @@ export class JointDataset {
}
});
} else if (this.dataDict) {
this.dataDict[index][labelColName] = val;
if (typeof val !== "string") {
this.dataDict[index][labelColName] = val;
} else if (this.strDataDict) {
this.strDataDict[index][labelColName] = val;
}
}
});
for (let i = 0; i < this.numLabels; i++) {
Expand Down Expand Up @@ -733,6 +739,13 @@ export class JointDataset {
if (arr === undefined) {
return;
}
if (this.strDataDict === undefined) {
this.strDataDict = Array.from({ length: arr.length }).map((_, index) => {
const dict = {};
dict[JointDataset.IndexLabel] = index;
return dict;
});
}
if (this.dataDict !== undefined) {
if (this.dataDict.length !== arr.length) {
throw new Error(
Expand Down
4 changes: 2 additions & 2 deletions libs/core-ui/src/lib/util/JointDatasetUtils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@ import { AxisTypes } from "./IGenericChartProps";

export interface IJointDatasetArgs {
dataset?: any[][];
predictedY?: number[] | number[][];
predictedY?: number[] | number[][] | string[];
predictedProbabilities?: number[][];
trueY?: number[] | number[][];
trueY?: number[] | number[][] | string[];
localExplanations?:
| IMultiClassLocalFeatureImportance
| ISingleClassLocalFeatureImportance;
Expand Down
31 changes: 30 additions & 1 deletion libs/core-ui/src/lib/util/MultilabelStatisticsUtils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,13 @@ import {
import { JointDataset } from "./JointDataset";

export enum MultilabelMetrics {
BleuScore = "bleuScore",
ExactMatchRatio = "exactMatchRatio",
HammingScore = "hammingScore"
HammingScore = "hammingScore",
F1Score = "f1Score",

MeteorScore = "meteorScore",
RougeScore = "rougeScore"
}

export const generateMultilabelStats: (
Expand Down Expand Up @@ -50,6 +55,10 @@ export const generateMultilabelStats: (
hammingScore = hammingScore / numLabels;
const sum = matchingLabels.reduce((prev, curr) => prev + curr, 0);
const exactMatchRatio = sum / (numLabels * selectionArray.length);
const meteorScore = 0;
const f1Score = 0;
const rougeScore = 0;
const bleuScore = 0;

return [
{
Expand All @@ -62,6 +71,26 @@ export const generateMultilabelStats: (
label: localization.Interpret.Statistics.exactMatchRatio,
stat: exactMatchRatio
},
{
key: MultilabelMetrics.MeteorScore,
label: localization.Interpret.Statistics.meteorScore,
stat: meteorScore
},
{
key: MultilabelMetrics.F1Score,
label: localization.Interpret.Statistics.f1Score,
stat: f1Score
},
{
key: MultilabelMetrics.BleuScore,
label: localization.Interpret.Statistics.bleuScore,
stat: bleuScore
},
{
key: MultilabelMetrics.RougeScore,
label: localization.Interpret.Statistics.rougeScore,
stat: rougeScore
},
{
key: MultilabelMetrics.HammingScore,
label: localization.Interpret.Statistics.hammingScore,
Expand Down
Loading

0 comments on commit 9c9adb2

Please sign in to comment.