diff --git a/.eslintrc/.eslintrc.custom.eslintrc b/.eslintrc/.eslintrc.custom.eslintrc index 012a404627..21dad46b5f 100644 --- a/.eslintrc/.eslintrc.custom.eslintrc +++ b/.eslintrc/.eslintrc.custom.eslintrc @@ -123,6 +123,7 @@ "target_column", "task_type", "test_data", + "time_series_id_column_names", "treatment_feature", "treatment_gains", "tree_features", diff --git a/apps/dashboard/src/model-assessment-forecasting/App.tsx b/apps/dashboard/src/model-assessment-forecasting/App.tsx index 92b8224dee..9801e10bfe 100644 --- a/apps/dashboard/src/model-assessment-forecasting/App.tsx +++ b/apps/dashboard/src/model-assessment-forecasting/App.tsx @@ -2,7 +2,6 @@ // Licensed under the MIT License. import { ITheme } from "@fluentui/react"; -import { HelpMessageDict } from "@responsible-ai/error-analysis"; import { Language } from "@responsible-ai/localization"; import { ModelAssessmentDashboard, @@ -11,6 +10,12 @@ import { } from "@responsible-ai/model-assessment"; import React from "react"; +import { + bobsSandwichesSandwich, + giorgiosPizzeriaBoston, + nonnasCannoliBoston +} from "./__mock_data__/mockForecastingData"; + interface IAppProps extends IModelAssessmentData { theme: ITheme; language: Language; @@ -20,28 +25,62 @@ interface IAppProps extends IModelAssessmentData { } export class App extends React.Component { - private messages: HelpMessageDict = { - LocalExpAndTestReq: [{ displayText: "LocalExpAndTestReq", format: "text" }], - LocalOrGlobalAndTestReq: [ - { displayText: "LocalOrGlobalAndTestReq", format: "text" } - ], - PredictorReq: [{ displayText: "PredictorReq", format: "text" }], - TestReq: [{ displayText: "TestReq", format: "text" }] - }; - public render(): React.ReactNode { this.props.modelExplanationData?.forEach( (modelExplanationData) => (modelExplanationData.modelClass = "blackbox") ); const modelAssessmentDashboardProps: IModelAssessmentDashboardProps = { ...this.props, - cohortData: [], + cohortData: [ + giorgiosPizzeriaBoston, + nonnasCannoliBoston, + bobsSandwichesSandwich + ], locale: this.props.language, localUrl: "https://www.bing.com/", - stringParams: { contextualHelp: this.messages }, - theme: this.props.theme + requestForecast: this.requestForecast }; return ; } + + private requestForecast = ( + x: any[], + abortSignal: AbortSignal + ): Promise => { + return new Promise((resolver) => { + setTimeout(() => { + if (abortSignal.aborted) { + return; + } + let start: number; + let end: number; + if (x[0][0].arg[0] === 1) { + // Giorgio's pizzeria + start = 0; + end = 10; + } else if (x[0][0].arg[0] === 0) { + // Bob's sandwiches + start = 10; + end = 20; + } else { + // Nonna's cannolis + start = 20; + end = 30; + } + const preds = this.props.dataset.predicted_y?.slice( + start, + end + ) as number[]; + if (x[2].length === 0) { + // return original predictions + resolver(preds); + } else { + // return predictions based on modified features + // we have to mock this part since we don't have a model available + resolver(preds.map((p) => p + 200 * (Math.random() - 0.5))); + } + }, 300); + }); + }; } diff --git a/apps/dashboard/src/model-assessment-forecasting/__mock_data__/mockForecastingData.ts b/apps/dashboard/src/model-assessment-forecasting/__mock_data__/mockForecastingData.ts index beed6b36a9..19d8fd514d 100644 --- a/apps/dashboard/src/model-assessment-forecasting/__mock_data__/mockForecastingData.ts +++ b/apps/dashboard/src/model-assessment-forecasting/__mock_data__/mockForecastingData.ts @@ -1,17 +1,152 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. -import { DatasetTaskType, IDataset } from "@responsible-ai/core-ui"; +import { + DatasetTaskType, + FilterMethods, + IDataset, + IPreBuiltCohort +} from "@responsible-ai/core-ui"; + +export const giorgiosPizzeriaBoston: IPreBuiltCohort = { + cohort_filter_list: [ + { + arg: ["Giorgio's pizzeria"], + column: "restaurant", + method: FilterMethods.Includes + }, + { + arg: ["Boston, MA"], + column: "city", + method: FilterMethods.Includes + } + ], + name: "restaurant = Giorgio's pizzeria, city = Boston, MA" +}; + +export const nonnasCannoliBoston: IPreBuiltCohort = { + cohort_filter_list: [ + { + arg: ["Nonna's cannoli"], + column: "restaurant", + method: FilterMethods.Includes + }, + { + arg: ["Boston, MA"], + column: "city", + method: FilterMethods.Includes + } + ], + name: "restaurant = Nonna's cannoli, city = Boston, MA" +}; + +export const bobsSandwichesSandwich: IPreBuiltCohort = { + cohort_filter_list: [ + { + arg: ["Bob's sandwiches"], + column: "restaurant", + method: FilterMethods.Includes + }, + { + arg: ["Sandwich, MA"], + column: "city", + method: FilterMethods.Includes + } + ], + name: "restaurant = Bob's sandwiches, city = Sandwich, MA" +}; // Based on how much money is spent on ads and the daily outside temperature // predict the number of people dining at a restaurant on any given day. export const mockForecastingData: IDataset = { - categorical_features: [], - feature_metadata: {}, - feature_names: [], + categorical_features: ["restaurant", "city"], + feature_metadata: { + categorical_features: ["restaurant", "city"], + time_series_id_column_names: ["restaurant", "city"] + }, + feature_names: ["ads", "temperature", "restaurant", "city"], - features: [], - predicted_y: [], - task_type: DatasetTaskType.Regression, - true_y: [] + features: [ + [1, 56, "Giorgio's pizzeria", "Boston, MA"], + [2, 65, "Giorgio's pizzeria", "Boston, MA"], + [1.3, 43, "Giorgio's pizzeria", "Boston, MA"], + [2.1, 55, "Giorgio's pizzeria", "Boston, MA"], + [1.6, 70, "Giorgio's pizzeria", "Boston, MA"], + [1.9, 67, "Giorgio's pizzeria", "Boston, MA"], + [1.3, 84, "Giorgio's pizzeria", "Boston, MA"], + [2.4, 76, "Giorgio's pizzeria", "Boston, MA"], + [1.9, 73, "Giorgio's pizzeria", "Boston, MA"], + [2.9, 61, "Giorgio's pizzeria", "Boston, MA"], + [0.2, 56, "Nonna's cannoli", "Boston, MA"], + [0.1, 65, "Nonna's cannoli", "Boston, MA"], + [0.4, 43, "Nonna's cannoli", "Boston, MA"], + [0.3, 55, "Nonna's cannoli", "Boston, MA"], + [0.2, 70, "Nonna's cannoli", "Boston, MA"], + [0.1, 67, "Nonna's cannoli", "Boston, MA"], + [0.3, 84, "Nonna's cannoli", "Boston, MA"], + [0.4, 76, "Nonna's cannoli", "Boston, MA"], + [0.3, 73, "Nonna's cannoli", "Boston, MA"], + [0.5, 61, "Nonna's cannoli", "Boston, MA"], + [3, 27, "Bob's sandwiches", "Sandwich, MA"], + [2.5, 31, "Bob's sandwiches", "Sandwich, MA"], + [2.7, 33, "Bob's sandwiches", "Sandwich, MA"], + [3.9, 47, "Bob's sandwiches", "Sandwich, MA"], + [3.4, 91, "Bob's sandwiches", "Sandwich, MA"], + [3.1, 87, "Bob's sandwiches", "Sandwich, MA"], + [1.9, 81, "Bob's sandwiches", "Sandwich, MA"], + [1.8, 34, "Bob's sandwiches", "Sandwich, MA"], + [3.4, 53, "Bob's sandwiches", "Sandwich, MA"], + [3, 62, "Bob's sandwiches", "Sandwich, MA"] + ], + index: [ + "10-10-2022", + "10-11-2022", + "10-12-2022", + "10-13-2022", + "10-14-2022", + "10-15-2022", + "10-16-2022", + "10-17-2022", + "10-18-2022", + "10-19-2022", + "10-10-2022", + "10-11-2022", + "10-12-2022", + "10-13-2022", + "10-14-2022", + "10-15-2022", + "10-16-2022", + "10-17-2022", + "10-18-2022", + "10-19-2022", + "10-10-2022", + "10-11-2022", + "10-12-2022", + "10-13-2022", + "10-14-2022", + "10-15-2022", + "10-16-2022", + "10-17-2022", + "10-18-2022", + "10-19-2022", + "10-10-2022", + "10-11-2022", + "10-12-2022", + "10-13-2022", + "10-14-2022", + "10-15-2022", + "10-16-2022", + "10-17-2022", + "10-18-2022", + "10-19-2022" + ], + predicted_y: [ + 213, 349, 320, 303, 511, 501, 762, 631, 599, 398, 243, 549, 390, 301, 311, + 701, 722, 681, 299, 498, 763, 149, 120, 103, 111, 101, 162, 131, 299, 198 + ], + task_type: DatasetTaskType.Forecasting, + true_y: [ + 240, 310, 342, 392, 514, 501, 795, 621, 600, 422, 222, 500, 345, 678, 343, + 454, 667, 399, 588, 440, 120, 99, 101, 110, 150, 130, 125, 127, 200, 187 + ] }; diff --git a/apps/widget/src/app/ModelAssessment.tsx b/apps/widget/src/app/ModelAssessment.tsx index f325c8500b..723f087bc6 100644 --- a/apps/widget/src/app/ModelAssessment.tsx +++ b/apps/widget/src/app/ModelAssessment.tsx @@ -36,6 +36,7 @@ export class ModelAssessment extends React.Component { | "requestBoxPlotDistribution" | "requestDatasetAnalysisBarChart" | "requestDatasetAnalysisBoxChart" + | "requestForecast" | "requestGlobalCausalEffects" | "requestGlobalCausalPolicy" | "requestGlobalExplanations" @@ -85,6 +86,9 @@ export class ModelAssessment extends React.Component { "/model_overview_probability_distribution" ); }; + callBack.requestForecast = async (data: any[]): Promise => { + return callFlaskService(this.props.config, data, "/forecast"); + }; callBack.requestGlobalCausalEffects = async ( id: string, filter: unknown[], diff --git a/libs/core-ui/src/index.ts b/libs/core-ui/src/index.ts index e1360c3699..a8a2e22261 100644 --- a/libs/core-ui/src/index.ts +++ b/libs/core-ui/src/index.ts @@ -2,6 +2,7 @@ // Licensed under the MIT License. export * from "./lib/cohortKey"; +export * from "./lib/Cohort/isAllDataCohort"; export * from "./lib/Cohort/Cohort"; export * from "./lib/Cohort/CohortList/CohortList"; export * from "./lib/Cohort/Constants"; @@ -42,6 +43,7 @@ export * from "./lib/util/Never"; export * from "./lib/util/PartialRequired"; export * from "./lib/util/nameof"; export * from "./lib/util/rowErrorSize"; +export * from "./lib/util/TimeUtils"; export * from "./lib/util/getBoxData"; export * from "./lib/util/getBasicFilterString"; export * from "./lib/util/getCommonStyles"; diff --git a/libs/core-ui/src/lib/Cohort/Cohort.ts b/libs/core-ui/src/lib/Cohort/Cohort.ts index 66bd2b9fd6..edaf0b4de0 100644 --- a/libs/core-ui/src/lib/Cohort/Cohort.ts +++ b/libs/core-ui/src/lib/Cohort/Cohort.ts @@ -15,7 +15,8 @@ export enum CohortSource { None = "None", TreeMap = "Tree map", HeatMap = "Heat map", - ManuallyCreated = "Manually created" + ManuallyCreated = "Manually created", + Prebuilt = "Prebuilt" } export class Cohort { diff --git a/libs/core-ui/src/lib/Cohort/CohortInfo/CohortInfo.tsx b/libs/core-ui/src/lib/Cohort/CohortInfo/CohortInfo.tsx index 8310917b17..cec5822fa9 100644 --- a/libs/core-ui/src/lib/Cohort/CohortInfo/CohortInfo.tsx +++ b/libs/core-ui/src/lib/Cohort/CohortInfo/CohortInfo.tsx @@ -8,6 +8,7 @@ import React from "react"; import { getCohortFilterCount } from "../../util/getCohortFilterCount"; import { ErrorCohortStats } from "../CohortStats"; import { ErrorCohort } from "../ErrorCohort"; +import { isAllDataErrorCohort } from "../isAllDataCohort"; import { PredictionPath } from "../PredictionPath/PredictionPath"; import { cohortInfoStyles } from "./CohortInfo.styles"; @@ -37,8 +38,7 @@ export class CohortInfo extends React.PureComponent { - {this.props.currentCohort.cohort.name !== - localization.ErrorAnalysis.Cohort.defaultLabel && ( + {!isAllDataErrorCohort(this.props.currentCohort, true) && ( {this.props.currentCohort.cohort.name} )} diff --git a/libs/core-ui/src/lib/Cohort/CohortInfoSection/CohortInfoSection.tsx b/libs/core-ui/src/lib/Cohort/CohortInfoSection/CohortInfoSection.tsx index cb0d1a178c..182bff70ae 100644 --- a/libs/core-ui/src/lib/Cohort/CohortInfoSection/CohortInfoSection.tsx +++ b/libs/core-ui/src/lib/Cohort/CohortInfoSection/CohortInfoSection.tsx @@ -10,6 +10,7 @@ import { IModelAssessmentContext, ModelAssessmentContext } from "../../Context/ModelAssessmentContext"; +import { isAllDataErrorCohort } from "../isAllDataCohort"; export interface ICohortInfoSectionProps { toggleShiftCohortVisibility: () => void; @@ -26,10 +27,7 @@ export class CohortInfoSection extends React.PureComponent Promise) | undefined; + requestForecast?: ( + request: any[], + abortSignal: AbortSignal + ) => Promise; shiftErrorCohort(cohort: ErrorCohort): void; addCohort(cohort: Cohort, switchNew?: boolean): void; editCohort(cohort: Cohort, switchNew?: boolean): void; diff --git a/libs/core-ui/src/lib/Interfaces/IDataset.ts b/libs/core-ui/src/lib/Interfaces/IDataset.ts index 650161d3be..2a4fa21a35 100644 --- a/libs/core-ui/src/lib/Interfaces/IDataset.ts +++ b/libs/core-ui/src/lib/Interfaces/IDataset.ts @@ -10,7 +10,8 @@ export enum DatasetTaskType { ImageClassification = "image_classification", TextClassification = "text_classification", MultilabelTextClassification = "multilabel_text_classification", - MultilabelImageClassification = "multilabel_image_classification" + MultilabelImageClassification = "multilabel_image_classification", + Forecasting = "forecasting" } export interface IDataset { @@ -28,6 +29,7 @@ export interface IDataset { data_balance_measures?: IDataBalanceMeasures; feature_metadata?: IFeatureMetaData; images?: string[]; + index?: string[]; } // TODO Remove DatasetSummary when possible diff --git a/libs/core-ui/src/lib/Interfaces/IMetaData.ts b/libs/core-ui/src/lib/Interfaces/IMetaData.ts index b9a0723b14..76741ba3f6 100644 --- a/libs/core-ui/src/lib/Interfaces/IMetaData.ts +++ b/libs/core-ui/src/lib/Interfaces/IMetaData.ts @@ -6,4 +6,5 @@ export interface IFeatureMetaData { datetime_features?: string[]; categorical_features?: string[]; dropped_features?: string[]; + time_series_id_column_names?: string[]; } diff --git a/libs/core-ui/src/lib/util/TimeUtils.ts b/libs/core-ui/src/lib/util/TimeUtils.ts new file mode 100644 index 0000000000..5fbaa8dc07 --- /dev/null +++ b/libs/core-ui/src/lib/util/TimeUtils.ts @@ -0,0 +1,15 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +export function orderByTime( + values: number[], + rowIndices: string[] +): Array<[number, number]> { + return values + .map((predictedValue: number, idx: number) => { + return [Date.parse(rowIndices[idx]), predictedValue] as [number, number]; + }) + .sort( + (objA: [number, number], objB: [number, number]) => objA[0] - objB[1] + ); +} diff --git a/libs/forecasting/src/lib/ForecastingDashboard/Controls/ForecastComparison.tsx b/libs/forecasting/src/lib/ForecastingDashboard/Controls/ForecastComparison.tsx new file mode 100644 index 0000000000..d3d8faf1f0 --- /dev/null +++ b/libs/forecasting/src/lib/ForecastingDashboard/Controls/ForecastComparison.tsx @@ -0,0 +1,175 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +import { Stack, Text } from "@fluentui/react"; +import { + ModelAssessmentContext, + defaultModelAssessmentContext, + BasicHighChart, + JointDataset, + orderByTime +} from "@responsible-ai/core-ui"; +import { localization } from "@responsible-ai/localization"; +import { SeriesOptionsType } from "highcharts"; +import React from "react"; + +import { forecastingDashboardStyles } from "../ForecastingDashboard.styles"; + +import { getForecastPrediction } from "./getForecastPrediction"; + +export class IForecastComparisonProps {} + +export interface IForecastComparisonState { + timeSeriesId?: number; + baselinePrediction?: Array<[number, number]>; + trueY?: Array<[number, number]>; +} + +const stackTokens = { + childrenGap: "l1" +}; + +export class ForecastComparison extends React.Component< + IForecastComparisonProps, + IForecastComparisonState +> { + public static contextType = ModelAssessmentContext; + public context: React.ContextType = + defaultModelAssessmentContext; + + public constructor(props: IForecastComparisonProps) { + super(props); + this.state = {}; + } + + public async componentDidMount(): Promise { + const trueY = this.getTrueY(); + const baselinePrediction = await this.getBaselineForecastPrediction(); + if (baselinePrediction) { + this.setState({ baselinePrediction, trueY }); + } + } + + public async componentDidUpdate(): Promise { + // Check if the time series was changed. + // In that case, we need to update our state accordingly. + const currentlySelectedTimeSeriesId = + this.context.baseErrorCohort.cohort.getCohortID(); + if (currentlySelectedTimeSeriesId !== this.state.timeSeriesId) { + const trueY = this.getTrueY(); + const baselinePrediction = await this.getBaselineForecastPrediction(); + this.setState({ + baselinePrediction, + timeSeriesId: currentlySelectedTimeSeriesId, + trueY + }); + } + } + + public render(): React.ReactNode { + const classNames = forecastingDashboardStyles(); + + if ( + this.context === undefined || + this.context.jointDataset.numLabels !== 1 + ) { + return; + } + + const trueY: SeriesOptionsType = { + data: this.state.trueY, + name: localization.Forecasting.trueY, + type: "spline" + }; + const seriesData: SeriesOptionsType[] = [trueY]; + if (this.state.baselinePrediction !== undefined) { + seriesData.push({ + data: this.state.baselinePrediction, + name: localization.Forecasting.baselinePrediction, + type: "spline" + } as SeriesOptionsType); + } + + return ( + + + + Compare What-if Forecasts + + + {seriesData !== undefined && ( + + + + )} + + ); + } + + private readonly getBaselineForecastPrediction = async (): Promise< + Array<[number, number]> | undefined + > => { + const baselinePrediction = await getForecastPrediction( + this.context.baseErrorCohort.cohort, + this.context.jointDataset, + this.context.requestForecast + ); + if (baselinePrediction && this.context.dataset.index) { + const dataIndex = this.context.dataset.index; + return orderByTime(baselinePrediction, this.getIndices(dataIndex)); + } + return undefined; + }; + + private readonly getTrueY = (): Array<[number, number]> | undefined => { + if (this.context.dataset.index) { + return orderByTime( + this.context.baseErrorCohort.cohort.filteredData.map( + (row) => row[JointDataset.TrueYLabel] + ), + this.getIndices(this.context.dataset.index) + ); + } + return undefined; + }; + + private readonly getIndices = (dataIndex: string[]): string[] => { + return this.context.baseErrorCohort.cohort.filteredData.map( + (datum) => dataIndex[datum.Index] + ); + }; +} diff --git a/libs/forecasting/src/lib/ForecastingDashboard/Controls/getForecastPrediction.ts b/libs/forecasting/src/lib/ForecastingDashboard/Controls/getForecastPrediction.ts new file mode 100644 index 0000000000..ce8bdbe39e --- /dev/null +++ b/libs/forecasting/src/lib/ForecastingDashboard/Controls/getForecastPrediction.ts @@ -0,0 +1,24 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +import { Cohort, JointDataset } from "@responsible-ai/core-ui"; + +export async function getForecastPrediction( + cohort: Cohort, + jointDataset: JointDataset, + requestForecast: + | ((request: any[], abortSignal: AbortSignal) => Promise) + | undefined +): Promise { + if (requestForecast === undefined) { + return; + } + return await requestForecast( + [ + Cohort.getLabeledFilters(cohort.filters, jointDataset), + Cohort.getLabeledCompositeFilters(cohort.compositeFilters, jointDataset), + [] + ], + new AbortController().signal + ); +} diff --git a/libs/forecasting/src/lib/ForecastingDashboard/ForecastingDashboard.styles.ts b/libs/forecasting/src/lib/ForecastingDashboard/ForecastingDashboard.styles.ts new file mode 100644 index 0000000000..4faf26e000 --- /dev/null +++ b/libs/forecasting/src/lib/ForecastingDashboard/ForecastingDashboard.styles.ts @@ -0,0 +1,92 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +import { + IStyle, + mergeStyleSets, + IProcessedStyleSet, + getTheme, + FontWeights +} from "@fluentui/react"; +import { descriptionMaxWidth, flexLgDown } from "@responsible-ai/core-ui"; + +export interface IForecastingDashboardStyles { + dropdown: IStyle; + sectionStack: IStyle; + configurationActionButton: IStyle; + topLevelDescriptionText: IStyle; + descriptionText: IStyle; + generalText: IStyle; + generalSemiBoldText: IStyle; + selections: IStyle; + smallDropdown: IStyle; + mediumText: IStyle; + forecastCategoryText: IStyle; + subMediumText: IStyle; + smallTextField: IStyle; + errorText: IStyle; +} + +export const forecastingDashboardStyles: () => IProcessedStyleSet = + () => { + const theme = getTheme(); + return mergeStyleSets({ + configurationActionButton: { + marginTop: "25px" + }, + descriptionText: { + color: theme.semanticColors.bodyText, + maxWidth: descriptionMaxWidth + }, + dropdown: { + selectors: { + "@media screen and (min-width: 1024px)": { + width: "600px" + } + }, + width: "auto" + }, + errorText: { + color: theme.semanticColors.errorText, + maxWidth: "200px" + }, + forecastCategoryText: { + fontSize: "14px", + fontWeight: "600", + lineHeight: "20px", + marginBottom: "20px" + }, + generalSemiBoldText: { + color: theme.semanticColors.bodyText, + fontWeight: FontWeights.semibold, + maxWidth: descriptionMaxWidth + }, + generalText: { + color: theme.semanticColors.bodyText + }, + mediumText: { + fontSize: "20px", + fontWeight: "600", + maxWidth: "200px" + }, + sectionStack: { + padding: "0 40px 40px 40px" + }, + selections: flexLgDown, + smallDropdown: { + width: "200px" + }, + smallTextField: { + width: "320px" + }, + subMediumText: { + fontSize: "16px", + fontWeight: "600", + lineHeight: "20px" + }, + topLevelDescriptionText: { + color: theme.semanticColors.bodyText, + maxWidth: descriptionMaxWidth + } + }); + }; diff --git a/libs/forecasting/src/lib/ForecastingDashboard/ForecastingDashboard.tsx b/libs/forecasting/src/lib/ForecastingDashboard/ForecastingDashboard.tsx index 6f33efe4fd..dda3a2bf01 100644 --- a/libs/forecasting/src/lib/ForecastingDashboard/ForecastingDashboard.tsx +++ b/libs/forecasting/src/lib/ForecastingDashboard/ForecastingDashboard.tsx @@ -1,19 +1,97 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. -import { Text } from "@fluentui/react"; +import { Dropdown, IDropdownOption, Stack, Text } from "@fluentui/react"; import { defaultModelAssessmentContext, + isAllDataErrorCohort, ModelAssessmentContext } from "@responsible-ai/core-ui"; +import { localization } from "@responsible-ai/localization"; import React from "react"; -export class ForecastingDashboard extends React.Component { +import { ForecastComparison } from "./Controls/ForecastComparison"; +import { forecastingDashboardStyles } from "./ForecastingDashboard.styles"; + +export class IForecastingDashboardProps {} + +export class IForecastingDashboardState {} + +export class ForecastingDashboard extends React.Component< + IForecastingDashboardProps, + IForecastingDashboardState +> { public static contextType = ModelAssessmentContext; public context: React.ContextType = defaultModelAssessmentContext; public render(): React.ReactNode { - return Placeholder; + const classNames = forecastingDashboardStyles(); + + if ( + this.context === undefined || + this.context.baseErrorCohort === undefined + ) { + return; + } + + // "All data" cohort selected, so no particular time series selected yet. + // special case: only 1 time series in dataset, needs to be handled! TODO + const noCohortSelected = isAllDataErrorCohort(this.context.baseErrorCohort); + + const dropdownOptions: IDropdownOption[] = this.context.errorCohorts + .filter((cohort) => !isAllDataErrorCohort(cohort)) + .map((cohort) => { + return { + key: cohort.cohort.getCohortID(), + text: cohort.cohort.name + }; + }); + + return ( + + + {localization.Forecasting.whatIfDescription} + + + + + {!noCohortSelected && ( + + + + )} + + ); } + + private onChangeCohort = ( + _event: React.FormEvent, + option?: IDropdownOption | undefined + ): void => { + if (option) { + const newCohortId = option.key as number; + const newCohort = this.context.errorCohorts.find( + (cohort) => cohort.cohort.getCohortID() === newCohortId + ); + if (newCohort) { + this.context.shiftErrorCohort(newCohort); + } + } + }; } diff --git a/libs/localization/src/lib/en.json b/libs/localization/src/lib/en.json index bcacc02d9f..ce5d7ce9de 100644 --- a/libs/localization/src/lib/en.json +++ b/libs/localization/src/lib/en.json @@ -1420,7 +1420,8 @@ "ModelOverview": "Model overview component added", "CausalAnalysis": "Causal analysis component added", "Counterfactuals": "Counterfactuals component added", - "Vision": "Vision data explorer component added" + "Vision": "Vision data explorer component added", + "Forecasting": "Forecasting what-if component added" } }, "CausalAnalysis": { @@ -1468,6 +1469,7 @@ "ErrorAnalysis": "Error analysis", "Fairness": "Fairness", "FeatureImportances": "Feature importances", + "Forecasting": "Forecasting", "ModelOverview": "Model overview", "TableView": "Table view", "VisionTab": "Vision data explorer" @@ -1820,5 +1822,44 @@ "TableViewTab": { "Heading": "View the dataset in a table format for all features and rows." } + }, + "Forecasting": { + "target": "Target", + "whatIfHeader": "What-if analysis", + "whatIfDescription": "What-if allows you to perturb features for any input and observe how the model's prediction changes. You can perturb features manually or specify the desired prediction (e.g., class label for a classifier) to see a list of closest data points to the original input that would lead to the desired prediction. Also known as prediction counterfactuals, you can use them for exploring the relationships learnt by the model; understanding important, necessary features for the model's predictions; or debug edge-cases for the model. To start, choose input points from the data table or scatter plot.", + "timeSeries": "Time series", + "selectTimeSeries": "Select a time series.", + "trueY": "True Y", + "baselinePrediction": "Baseline prediction", + "forecastComparisonChartTitle": "Forecasts", + "forecastComparisonChartTimeAxisLabel": "Time", + "Transformations": { + "multiply": "Multiply", + "divide": "Divide", + "add": "Add", + "subtract": "Subtract" + }, + "TransformationCreation": { + "title": "Create what-if scenario", + "nameLabel": "What-if scenario name", + "featureInstructions": "Choose a feature to perturb.", + "operationInstructions": "Choose an operation to apply to the feature.", + "operationDropdownHeader": "Operation", + "featureDropdownHeader": "Feature", + "valueSpinButtonHeader": "Value", + "scenarioNamingInstructionsPlaceholder": "Enter a unique name", + "scenarioNamingInstructions": "Enter a name for your what-if scenario.", + "scenarioNamingCollisionMessage": "This name exists already. Please enter a unique name.", + "valueErrorMessage": "For operation {0} please select a value between {1} and {2} other than {3}.", + "invalidCombinationErrorMessage": "This is identical to an existing what-if scenario. Please change the feature, operation, or value.", + "addTransformationButton": "Add Transformation", + "divisionAndMultiplicationBy": "by" + }, + "TransformationTable": { + "nameColumnHeader": "Name", + "methodColumnHeader": "Method", + "divisionAndMultiplicationBy": "by ", + "header": "What-if Forecasts ({0})" + } } } diff --git a/libs/model-assessment/src/lib/ModelAssessmentDashboard/AvailableTabs.ts b/libs/model-assessment/src/lib/ModelAssessmentDashboard/AvailableTabs.ts index d9db917ce6..2bee36a0ff 100644 --- a/libs/model-assessment/src/lib/ModelAssessmentDashboard/AvailableTabs.ts +++ b/libs/model-assessment/src/lib/ModelAssessmentDashboard/AvailableTabs.ts @@ -2,6 +2,7 @@ // Licensed under the MIT License. import { IDropdownOption } from "@fluentui/react"; +import { DatasetTaskType } from "@responsible-ai/core-ui"; import { localization } from "@responsible-ai/localization"; import { IModelAssessmentDashboardProps } from "./ModelAssessmentDashboardProps"; @@ -21,7 +22,12 @@ export function getAvailableTabs( text: localization.ModelAssessment.ComponentNames.ErrorAnalysis }); } - + if (props.dataset.task_type === DatasetTaskType.Forecasting) { + availableTabs.push({ + key: GlobalTabKeys.ForecastingTab, + text: localization.ModelAssessment.ComponentNames.Forecasting + }); + } if (props.dataset.images) { availableTabs.push({ key: GlobalTabKeys.VisionTab, @@ -29,16 +35,22 @@ export function getAvailableTabs( }); } - if (props.dataset.predicted_y) { + if ( + props.dataset.predicted_y && + props.dataset.task_type !== DatasetTaskType.Forecasting + ) { availableTabs.push({ key: GlobalTabKeys.ModelOverviewTab, text: localization.ModelAssessment.ComponentNames.ModelOverview }); } - availableTabs.push({ - key: GlobalTabKeys.DataAnalysisTab, - text: localization.ModelAssessment.ComponentNames.DataAnalysis - }); + + if (props.dataset.task_type !== DatasetTaskType.Forecasting) { + availableTabs.push({ + key: GlobalTabKeys.DataAnalysisTab, + text: localization.ModelAssessment.ComponentNames.DataAnalysis + }); + } if (props.modelExplanationData && props.modelExplanationData.length > 0) { availableTabs.push({ diff --git a/libs/model-assessment/src/lib/ModelAssessmentDashboard/Cohort/ChangeGlobalCohort.tsx b/libs/model-assessment/src/lib/ModelAssessmentDashboard/Cohort/ChangeGlobalCohort.tsx index d304b6cd75..b8d790af91 100644 --- a/libs/model-assessment/src/lib/ModelAssessmentDashboard/Cohort/ChangeGlobalCohort.tsx +++ b/libs/model-assessment/src/lib/ModelAssessmentDashboard/Cohort/ChangeGlobalCohort.tsx @@ -12,6 +12,7 @@ import React from "react"; import { ShiftCohort } from "./ShiftCohort"; interface IChangeGlobalCohortProps { visible: boolean; + showAllDataCohort: boolean; onDismiss(): void; } export class ChangeGlobalCohort extends React.Component { @@ -24,6 +25,7 @@ export class ChangeGlobalCohort extends React.Component ) ); diff --git a/libs/model-assessment/src/lib/ModelAssessmentDashboard/Cohort/ChangeGlobalCohortButton.tsx b/libs/model-assessment/src/lib/ModelAssessmentDashboard/Cohort/ChangeGlobalCohortButton.tsx index f3a1346453..d0adb53f48 100644 --- a/libs/model-assessment/src/lib/ModelAssessmentDashboard/Cohort/ChangeGlobalCohortButton.tsx +++ b/libs/model-assessment/src/lib/ModelAssessmentDashboard/Cohort/ChangeGlobalCohortButton.tsx @@ -12,17 +12,21 @@ import React from "react"; import { ChangeGlobalCohort } from "./ChangeGlobalCohort"; +interface IChangeGlobalCohortButtonProps { + showAllDataCohort: boolean; +} + interface IChangeGlobalCohortButtonState { shiftCohortVisible: boolean; } export class ChangeGlobalCohortButton extends React.Component< - Record, + IChangeGlobalCohortButtonProps, IChangeGlobalCohortButtonState > { public static contextType = ModelAssessmentContext; public context: IModelAssessmentContext = defaultModelAssessmentContext; - public constructor(props: Record) { + public constructor(props: IChangeGlobalCohortButtonProps) { super(props); this.state = { shiftCohortVisible: false }; } @@ -36,6 +40,7 @@ export class ChangeGlobalCohortButton extends React.Component< ); diff --git a/libs/model-assessment/src/lib/ModelAssessmentDashboard/Cohort/CohortList.tsx b/libs/model-assessment/src/lib/ModelAssessmentDashboard/Cohort/CohortList.tsx index b0ab2fa189..3fbbdb8668 100644 --- a/libs/model-assessment/src/lib/ModelAssessmentDashboard/Cohort/CohortList.tsx +++ b/libs/model-assessment/src/lib/ModelAssessmentDashboard/Cohort/CohortList.tsx @@ -18,6 +18,7 @@ import { ErrorCohort, getCohortFilterCount, IModelAssessmentContext, + isAllDataErrorCohort, ModelAssessmentContext } from "@responsible-ai/core-ui"; import { localization } from "@responsible-ai/localization"; @@ -31,6 +32,7 @@ export interface ICohortListProps { onEditCohortClick?: (editedCohort: ErrorCohort) => void; onRemoveCohortClick?: (editedCohort: ErrorCohort) => void; enableEditing: boolean; + showAllDataCohort: boolean; } export interface ICohortListItem { @@ -64,7 +66,7 @@ export class CohortList extends React.Component< fieldName: "name", isResizable: true, key: "nameColumn", - maxWidth: 200, + maxWidth: 400, minWidth: 50, name: "Name" }, @@ -178,13 +180,15 @@ export class CohortList extends React.Component< - - - + {this.props.enableEditing && ( + + + + )} ); } @@ -288,6 +292,11 @@ export class CohortList extends React.Component< private getCohortListItems(): ICohortListItem[] { const allItems = this.context.errorCohorts + .filter( + (errorCohort: ErrorCohort) => + this.props.showAllDataCohort || + !isAllDataErrorCohort(errorCohort, true) + ) .filter((errorCohort: ErrorCohort) => !errorCohort.isTemporary) .map((errorCohort: ErrorCohort, index: number) => { const details = [ diff --git a/libs/model-assessment/src/lib/ModelAssessmentDashboard/Cohort/CohortSettingsPanel.tsx b/libs/model-assessment/src/lib/ModelAssessmentDashboard/Cohort/CohortSettingsPanel.tsx index 57408a86a9..d52de82840 100644 --- a/libs/model-assessment/src/lib/ModelAssessmentDashboard/Cohort/CohortSettingsPanel.tsx +++ b/libs/model-assessment/src/lib/ModelAssessmentDashboard/Cohort/CohortSettingsPanel.tsx @@ -17,6 +17,8 @@ import { CreateGlobalCohortButton } from "./CreateGlobalCohortButton"; export interface ICohortSettingsPanelProps { isOpen: boolean; onDismiss: () => void; + allowCohortEditing: boolean; + showAllDataCohort: boolean; } export class CohortSettingsPanel extends React.PureComponent { @@ -49,15 +51,22 @@ export class CohortSettingsPanel extends React.PureComponent - - - - + + {this.props.allowCohortEditing && ( + + + + )} - + diff --git a/libs/model-assessment/src/lib/ModelAssessmentDashboard/Cohort/ProcessPreBuiltCohort.ts b/libs/model-assessment/src/lib/ModelAssessmentDashboard/Cohort/ProcessPreBuiltCohort.ts index c1f40e2473..e77e61aedf 100644 --- a/libs/model-assessment/src/lib/ModelAssessmentDashboard/Cohort/ProcessPreBuiltCohort.ts +++ b/libs/model-assessment/src/lib/ModelAssessmentDashboard/Cohort/ProcessPreBuiltCohort.ts @@ -8,7 +8,8 @@ import { IsClassifier, FilterMethods, Cohort, - IPreBuiltFilter + IPreBuiltFilter, + CohortSource } from "@responsible-ai/core-ui"; import { localization } from "@responsible-ai/localization"; @@ -96,7 +97,9 @@ export function processPreBuiltCohort( } const errorCohortEntry = new ErrorCohort( new Cohort(preBuiltCohort.name, jointDataset, filterList), - jointDataset + jointDataset, + undefined, + CohortSource.Prebuilt ); errorCohortList.push(errorCohortEntry); } diff --git a/libs/model-assessment/src/lib/ModelAssessmentDashboard/Cohort/ShiftCohort.tsx b/libs/model-assessment/src/lib/ModelAssessmentDashboard/Cohort/ShiftCohort.tsx index 34b6e33512..3ae107336d 100644 --- a/libs/model-assessment/src/lib/ModelAssessmentDashboard/Cohort/ShiftCohort.tsx +++ b/libs/model-assessment/src/lib/ModelAssessmentDashboard/Cohort/ShiftCohort.tsx @@ -15,6 +15,7 @@ import { CohortEditorFilterList, defaultModelAssessmentContext, ErrorCohort, + isAllDataErrorCohort, ModelAssessmentContext } from "@responsible-ai/core-ui"; import { localization } from "@responsible-ai/localization"; @@ -24,6 +25,7 @@ export interface IShiftCohortProps { onDismiss: () => void; onApply: (selectedCohort: ErrorCohort) => void; defaultCohort?: ErrorCohort; + showAllDataCohort: boolean; } export interface IShiftCohortState { @@ -41,9 +43,14 @@ export class ShiftCohort extends React.Component< defaultModelAssessmentContext; public componentDidMount(): void { - const savedCohorts = this.context.errorCohorts.filter( - (errorCohort) => !errorCohort.isTemporary - ); + const savedCohorts = this.context.errorCohorts + .filter((errorCohort) => !errorCohort.isTemporary) + .filter( + (errorCohort) => + !errorCohort.isTemporary && + (this.props.showAllDataCohort || + !isAllDataErrorCohort(errorCohort, true)) + ); const options: IDropdownOption[] = savedCohorts.map( (savedCohort: ErrorCohort, index: number) => { return { key: index, text: savedCohort.cohort.name }; diff --git a/libs/model-assessment/src/lib/ModelAssessmentDashboard/Controls/MainMenu.tsx b/libs/model-assessment/src/lib/ModelAssessmentDashboard/Controls/MainMenu.tsx index 97abef5b35..034f2f4a18 100644 --- a/libs/model-assessment/src/lib/ModelAssessmentDashboard/Controls/MainMenu.tsx +++ b/libs/model-assessment/src/lib/ModelAssessmentDashboard/Controls/MainMenu.tsx @@ -12,9 +12,11 @@ import { TooltipHost } from "@fluentui/react"; import { + DatasetTaskType, defaultModelAssessmentContext, ErrorCohort, IModelAssessmentContext, + isAllDataErrorCohort, ITelemetryEvent, ModelAssessmentContext, TelemetryEventName, @@ -93,7 +95,9 @@ export class MainMenu extends React.PureComponent< public render(): React.ReactNode { const classNames = mainMenuStyles(); - const menuItems: ICommandBarItemProps[] = [ + let allowCohortEditing = true; + let showAllDataCohort = true; + let menuItems: ICommandBarItemProps[] = [ { className: classNames.mainMenuItem, key: "cohortName", @@ -116,6 +120,15 @@ export class MainMenu extends React.PureComponent< text: localization.ModelAssessment.CohortInformation.NewCohort } ]; + + if (this.context.dataset.task_type === DatasetTaskType.Forecasting) { + // Creating and switching cohorts is handled differently for forecasting + // since we need to work with time series as cohorts only. + menuItems = []; + allowCohortEditing = false; + showAllDataCohort = false; + } + return ( <>
@@ -130,6 +143,8 @@ export class MainMenu extends React.PureComponent< { @@ -131,21 +137,21 @@ export class TabsView extends React.PureComponent< const disabledView = this.props.requestDebugML === undefined && this.props.requestMatrix === undefined && - this.props.baseCohort.cohort.name !== - localization.ErrorAnalysis.Cohort.defaultLabel; + !isAllDataErrorCohort(this.props.baseCohort, true); const classNames = tabsViewStyles(); return ( {this.props.activeGlobalTabs[0]?.key !== - GlobalTabKeys.ErrorAnalysisTab && ( - - - - )} + GlobalTabKeys.ErrorAnalysisTab && + this.context.dataset.task_type !== DatasetTaskType.Forecasting && ( + + + + )} {this.props.activeGlobalTabs.map((t, i) => ( <> )} + {t.key === GlobalTabKeys.ForecastingTab && ( + <> +

+ + {localization.Forecasting.whatIfHeader} + +

+ + + )} {t.key === GlobalTabKeys.ErrorAnalysisTab && this.props.errorAnalysisData?.[0] && ( <> @@ -428,13 +447,15 @@ export class TabsView extends React.PureComponent< )}
- - - + {this.context.dataset.task_type !== DatasetTaskType.Forecasting && ( + + + + )} ))} {this.state.mapShiftVisible && ( diff --git a/libs/model-assessment/src/lib/ModelAssessmentDashboard/ModelAssessmentDashboard.tsx b/libs/model-assessment/src/lib/ModelAssessmentDashboard/ModelAssessmentDashboard.tsx index 7d8e4c84b5..b98d322707 100644 --- a/libs/model-assessment/src/lib/ModelAssessmentDashboard/ModelAssessmentDashboard.tsx +++ b/libs/model-assessment/src/lib/ModelAssessmentDashboard/ModelAssessmentDashboard.tsx @@ -84,6 +84,7 @@ export class ModelAssessmentDashboard extends CohortBasedComponent< requestDatasetAnalysisBoxChart: this.props.requestDatasetAnalysisBoxChart, requestExp: this.props.requestExp, + requestForecast: this.props.requestForecast, requestGlobalCausalEffects: this.props.requestGlobalCausalEffects, requestGlobalCausalPolicy: this.props.requestGlobalCausalPolicy, requestGlobalExplanations: this.props.requestGlobalExplanations, diff --git a/libs/model-assessment/src/lib/ModelAssessmentDashboard/ModelAssessmentDashboardProps.ts b/libs/model-assessment/src/lib/ModelAssessmentDashboard/ModelAssessmentDashboardProps.ts index 9745eaaf09..8dafc8e37f 100644 --- a/libs/model-assessment/src/lib/ModelAssessmentDashboard/ModelAssessmentDashboardProps.ts +++ b/libs/model-assessment/src/lib/ModelAssessmentDashboard/ModelAssessmentDashboardProps.ts @@ -120,7 +120,10 @@ export interface IModelAssessmentDashboardProps abortSignal: AbortSignal ) => Promise; localUrl?: string; - + requestForecast?: ( + request: any[], + abortSignal: AbortSignal + ) => Promise; telemetryHook?: (message: ITelemetryEvent) => void; // TODO figure out how to persist starting tab for fairness diff --git a/libs/model-assessment/src/lib/ModelAssessmentDashboard/ModelAssessmentEnums.ts b/libs/model-assessment/src/lib/ModelAssessmentDashboard/ModelAssessmentEnums.ts index f976fb99a5..798c064e63 100644 --- a/libs/model-assessment/src/lib/ModelAssessmentDashboard/ModelAssessmentEnums.ts +++ b/libs/model-assessment/src/lib/ModelAssessmentDashboard/ModelAssessmentEnums.ts @@ -9,5 +9,6 @@ export enum GlobalTabKeys { ModelOverviewTab = "ModelOverviewTab", CausalAnalysisTab = "CausalAnalysisTab", CounterfactualsTab = "CounterfactualsTab", - VisionTab = "VisionExplanationTab" + VisionTab = "VisionExplanationTab", + ForecastingTab = "ForecastingTab" } diff --git a/libs/model-assessment/src/lib/ModelAssessmentDashboard/utils/addTabMessage.ts b/libs/model-assessment/src/lib/ModelAssessmentDashboard/utils/addTabMessage.ts index fcce82f6e5..83fd2d381e 100644 --- a/libs/model-assessment/src/lib/ModelAssessmentDashboard/utils/addTabMessage.ts +++ b/libs/model-assessment/src/lib/ModelAssessmentDashboard/utils/addTabMessage.ts @@ -24,6 +24,8 @@ export function addTabMessage(tab: GlobalTabKeys): string { return strings.ModelOverview; case GlobalTabKeys.VisionTab: return strings.Vision; + case GlobalTabKeys.ForecastingTab: + return strings.Forecasting; default: throw new Error(`Unexpected component ${tab}.`); }