Skip to content

Commit

Permalink
add multilabel support to RAI dashboard and multilabel text classific…
Browse files Browse the repository at this point in the history
…ation covid events dataset
  • Loading branch information
imatiach-msft committed Nov 28, 2022
1 parent 4c34d81 commit 69d7553
Show file tree
Hide file tree
Showing 16 changed files with 886 additions and 80 deletions.
684 changes: 684 additions & 0 deletions apps/dashboard/src/model-assessment-text/__mock_data__/covidevents.ts

Large diffs are not rendered by default.

9 changes: 5 additions & 4 deletions libs/core-ui/src/lib/Interfaces/IDataset.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,22 @@ export enum DatasetTaskType {
Regression = "regression",
Classification = "classification",
ImageClassification = "image_classification",
TextClassification = "text_classification"
TextClassification = "text_classification",
MultilabelTextClassification = "multilabel_text_classification"
}

export interface IDataset {
task_type: DatasetTaskType;
true_y: number[];
predicted_y?: number[];
true_y: number[] | number[][];
predicted_y?: number[] | number[][];
probability_y?: number[][];
features: unknown[][];
feature_names: string[];
categorical_features: string[];
is_large_data_scenario?: boolean;
use_entire_test_data?: boolean;
class_names?: string[];
target_column?: string;
target_column?: string | string[];
data_balance_measures?: IDataBalanceMeasures;
feature_metadata?: IFeatureMetaData;
images?: string[];
Expand Down
2 changes: 1 addition & 1 deletion libs/core-ui/src/lib/Interfaces/IModelExplanationData.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import { IPrecomputedExplanations } from "./ExplanationInterfaces";
export interface IModelExplanationData {
modelClass?: ModelClass;
method?: Method;
predictedY?: number[];
predictedY?: number[] | number[][];
probabilityY?: number[][];
explanationMethod?: string;
precomputedExplanations?: IPrecomputedExplanations;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
// Licensed under the MIT License.

export interface IVisionListItem {
[key: string]: string | number | boolean;
[key: string]: string | number | boolean | string[];
image: string;
predictedY: string;
trueY: string;
predictedY: string | string[];
trueY: string | string[];
index: number;
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
import { DatasetTaskType } from "./IDataset";
export interface IVisionExplanationDashboardData {
task_type: DatasetTaskType;
true_y: number[];
predicted_y: number[];
true_y: number[] | number[][];
predicted_y: number[] | number[][];
features?: unknown[][];
feature_names?: string[];
class_names: string[];
Expand Down
47 changes: 45 additions & 2 deletions libs/core-ui/src/lib/util/DatasetUtils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,28 @@ export function constructRows(
tableRow.push(colors[i]);
}
if (jointDataset.hasTrueY) {
pushRowData(tableRow, JointDataset.TrueYLabel, jointDataset, row);
if (jointDataset.numLabels > 1) {
pushMultilabelRowData(
tableRow,
JointDataset.TrueYLabel,
jointDataset,
row
);
} else {
pushRowData(tableRow, JointDataset.TrueYLabel, jointDataset, row);
}
}
if (jointDataset.hasPredictedY) {
pushRowData(tableRow, JointDataset.PredictedYLabel, jointDataset, row);
if (jointDataset.numLabels > 1) {
pushMultilabelRowData(
tableRow,
JointDataset.PredictedYLabel,
jointDataset,
row
);
} else {
pushRowData(tableRow, JointDataset.PredictedYLabel, jointDataset, row);
}
}
tableRow.push(...data);
rows.push(tableRow);
Expand Down Expand Up @@ -152,3 +170,28 @@ function pushRowData(
tableRow.push(row[property]);
}
}

function pushMultilabelRowData(
tableRow: any[],
property: string,
jointDataset: JointDataset,
row: { [key: string]: number }
): void {
const values = [];
for (let i = 0; i < jointDataset.numLabels; i++) {
const labelProp = property + i.toString();
const categories = jointDataset.metaDict[labelProp].sortedCategoricalValues;
if (jointDataset.metaDict[labelProp].isCategorical && categories) {
const value = categories[row[labelProp]];
if (value) {
values.push(value);
}
} else {
const value = row[labelProp];
if (value) {
values.push(value);
}
}
}
tableRow.push(values.join(","));
}
133 changes: 85 additions & 48 deletions libs/core-ui/src/lib/util/JointDataset.ts
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ export class JointDataset {
public predictionClassCount = 0;
public datasetRowCount = 0;
public localExplanationFeatureCount = 0;
public numLabels = 1;

// these properties should only be accessed by Cohort class,
// which enables independent filtered views of this data
Expand Down Expand Up @@ -137,30 +138,14 @@ export class JointDataset {
this.hasDataset = true;
}
if (args.predictedY) {
this.initializeDataDictIfNeeded(args.predictedY);
args.predictedY.forEach((val, index) => {
if (this.dataDict) {
this.dataDict[index][JointDataset.PredictedYLabel] = val;
}
});
this.metaDict[JointDataset.PredictedYLabel] = {
abbridgedLabel: localization.Interpret.ExplanationScatter.predictedY,
category: ColumnCategories.Outcome,
isCategorical: args.metadata.modelType !== ModelTypes.Regression,
label: localization.Interpret.ExplanationScatter.predictedY,
sortedCategoricalValues:
args.metadata.modelType !== ModelTypes.Regression
? args.metadata.classNames
: undefined,
treatAsCategorical: args.metadata.modelType !== ModelTypes.Regression
};
if (args.metadata.modelType === ModelTypes.Regression) {
this.metaDict[JointDataset.PredictedYLabel].featureRange = {
max: _.max(args.predictedY) || 0,
min: _.min(args.predictedY) || 0,
rangeType: RangeTypes.Numeric
};
}
this.updateMetaDataDict(
args.predictedY,
args.metadata,
JointDataset.PredictedYLabel,
localization.Interpret.ExplanationScatter.predictedY,
localization.Interpret.ExplanationScatter.predictedY,
args.targetColumn
);
this.hasPredictedY = true;
}
if (args.predictedProbabilities) {
Expand Down Expand Up @@ -204,30 +189,14 @@ export class JointDataset {
}
}
if (args.trueY) {
this.initializeDataDictIfNeeded(args.trueY);
args.trueY.forEach((val, index) => {
if (this.dataDict) {
this.dataDict[index][JointDataset.TrueYLabel] = val;
}
});
this.metaDict[JointDataset.TrueYLabel] = {
abbridgedLabel: localization.Interpret.ExplanationScatter.trueY,
category: ColumnCategories.Outcome,
isCategorical: args.metadata.modelType !== ModelTypes.Regression,
label: localization.Interpret.ExplanationScatter.trueY,
sortedCategoricalValues:
args.metadata.modelType !== ModelTypes.Regression
? args.metadata.classNames
: undefined,
treatAsCategorical: args.metadata.modelType !== ModelTypes.Regression
};
if (args.metadata.modelType === ModelTypes.Regression) {
this.metaDict[JointDataset.TrueYLabel].featureRange = {
max: _.max(args.trueY) || 0,
min: _.min(args.trueY) || 0,
rangeType: RangeTypes.Numeric
};
}
this.updateMetaDataDict(
args.trueY,
args.metadata,
JointDataset.TrueYLabel,
localization.Interpret.ExplanationScatter.trueY,
localization.Interpret.ExplanationScatter.trueY,
args.targetColumn
);
this.hasTrueY = true;
}
// include error columns if applicable
Expand Down Expand Up @@ -677,6 +646,74 @@ export class JointDataset {
return undefined;
}

private updateMetaDataDict(
values: number[] | number[][],
metadata: IExplanationModelMetadata,
labelColName: string,
abbridgedLabel: string,
label: string,
targetColumn?: string | string[]
): void {
this.initializeDataDictIfNeeded(values);
values.forEach((val, index) => {
if (Array.isArray(val)) {
this.numLabels = val.length;
val.forEach((subVal, subIndex) => {
if (this.dataDict) {
this.dataDict[index][labelColName + subIndex.toString()] = subVal;
}
});
} else {
if (this.dataDict) {
this.dataDict[index][labelColName] = val;
}
}
});
for (let i = 0; i < this.numLabels; i++) {
let labelColNameKey = labelColName;
let abbridgedLabelValue = abbridgedLabel;
let labelValue = label;
let singleLabelValues: number[] = [];
if (this.numLabels > 1) {
const labelIdxStr = i.toString();
labelColNameKey += labelIdxStr;
abbridgedLabelValue += labelIdxStr;
labelValue += labelIdxStr;
// check if values is a 2d array
const indexedValues = values[i];
if (Array.isArray(indexedValues)) {
singleLabelValues = indexedValues;
}
} else {
if (!Array.isArray(values)) {
singleLabelValues = values;
}
}
let categoricalValues =
metadata.modelType !== ModelTypes.Regression
? metadata.classNames
: undefined;
if (this.numLabels > 1 && Array.isArray(targetColumn)) {
categoricalValues = ["", targetColumn[i]];
}
this.metaDict[labelColNameKey] = {
abbridgedLabel: abbridgedLabelValue,
category: ColumnCategories.Outcome,
isCategorical: metadata.modelType !== ModelTypes.Regression,
label: labelValue,
sortedCategoricalValues: categoricalValues,
treatAsCategorical: metadata.modelType !== ModelTypes.Regression
};
if (metadata.modelType === ModelTypes.Regression) {
this.metaDict[labelColNameKey].featureRange = {
max: _.max(singleLabelValues) || 0,
min: _.min(singleLabelValues) || 0,
rangeType: RangeTypes.Numeric
};
}
}
}

private initializeDataDictIfNeeded(arr: any[]): void {
if (arr === undefined) {
return;
Expand Down
5 changes: 3 additions & 2 deletions libs/core-ui/src/lib/util/JointDatasetUtils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,15 @@ import { AxisTypes } from "./IGenericChartProps";

export interface IJointDatasetArgs {
dataset?: any[][];
predictedY?: number[];
predictedY?: number[] | number[][];
predictedProbabilities?: number[][];
trueY?: number[];
trueY?: number[] | number[][];
localExplanations?:
| IMultiClassLocalFeatureImportance
| ISingleClassLocalFeatureImportance;
metadata: IExplanationModelMetadata;
featureMetaData?: IFeatureMetaData;
targetColumn?: string | string[];
}

export enum ColumnCategories {
Expand Down
5 changes: 4 additions & 1 deletion libs/counterfactuals/src/util/getOriginalData.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,10 @@ export function getOriginalData(
featureNames.forEach((f, index) => {
data[f] = dataPoint[index];
});
const targetLabel = dataset.target_column || "y";
const targetColumn = Array.isArray(dataset.target_column)
? dataset.target_column?.[0]
: dataset.target_column;
const targetLabel = targetColumn || "y";
data[targetLabel] = row[JointDataset.TrueYLabel];
return data;
}
Original file line number Diff line number Diff line change
Expand Up @@ -176,20 +176,22 @@ export class DataCharacteristics extends React.Component<
): React.ReactElement => {
const imageDim = this.props.imageDim;
const classNames = dataCharacteristicsStyles();
const predictedY = item?.predictedY;
const indicatorStyle = mergeStyles(
classNames.indicator,
{ width: imageDim },
item?.predictedY === item?.trueY
predictedY === item?.trueY
? classNames.successIndicator
: classNames.errorIndicator
);
const alt = Array.isArray(predictedY) ? predictedY.join(",") : predictedY;
return !item ? (
<div />
) : (
<Stack className={classNames.tile}>
<Stack.Item style={{ height: imageDim, width: imageDim }}>
<Image
alt={item?.predictedY}
alt={alt}
src={`data:image/jpg;base64,${item?.image}`}
onClick={this.callbackWrapper(item)}
className={classNames.image}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ export class DataCharacteristicsRow extends React.Component<IDataCharacteristics
const listContainerStyle = mergeStyles(classNames.listContainer, {
height: imageDim + 30
});
const predictedY = this.props.list[0].predictedY;
const key = Array.isArray(predictedY) ? predictedY.join(",") : predictedY;
return (
<Stack>
<Stack.Item>
Expand All @@ -93,7 +95,7 @@ export class DataCharacteristicsRow extends React.Component<IDataCharacteristics
)}
<Stack.Item className={listContainerStyle}>
<List
key={this.props.list[0].predictedY}
key={key}
items={list}
onRenderCell={onRenderCell}
getPageHeight={getPageHeight}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,12 @@ export class Flyout extends React.Component<IFlyoutProps, IFlyoutState> {
const fieldNames = this.props.otherMetadataFieldNames;
const metadata: Array<Array<string | number | boolean>> = [];
fieldNames.forEach((fieldName) => {
const itemField = item[fieldName];
const itemValue = Array.isArray(itemField)
? itemField.join(",")
: itemField;
if (item[fieldName]) {
metadata.push([fieldName, item[fieldName]]);
metadata.push([fieldName, itemValue]);
}
});
this.setState({ item, metadata });
Expand All @@ -72,8 +76,12 @@ export class Flyout extends React.Component<IFlyoutProps, IFlyoutState> {
const fieldNames = this.props.otherMetadataFieldNames;
const metadata: Array<Array<string | number | boolean>> = [];
fieldNames.forEach((fieldName) => {
const itemField = item[fieldName];
const itemValue = Array.isArray(itemField)
? itemField.join(",")
: itemField;
if (item[fieldName]) {
metadata.push([fieldName, item[fieldName]]);
metadata.push([fieldName, itemValue]);
}
});
this.setState({
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,8 @@ export class ImageList extends React.Component<
if (!item) {
return;
}
const predictedY = item?.predictedY;
const alt = Array.isArray(predictedY) ? predictedY.join(",") : predictedY;

return (
<Stack
Expand All @@ -122,7 +124,7 @@ export class ImageList extends React.Component<
>
<Image
{...imageProps}
alt={item?.predictedY}
alt={alt}
src={`data:image/jpg;base64,${item?.image}`}
onClick={this.callbackWrapper(item)}
width={this.props.imageDim}
Expand Down
Loading

0 comments on commit 69d7553

Please sign in to comment.