diff --git a/extension/src/experiments/index.ts b/extension/src/experiments/index.ts index b04a2af4bb..b9c3744481 100644 --- a/extension/src/experiments/index.ts +++ b/extension/src/experiments/index.ts @@ -439,6 +439,10 @@ export class Experiments extends BaseRepository { return this.experiments.getFinishedExperiments() } + public getExperiments() { + return this.experiments.getExperiments() + } + public getExperimentDisplayName(experimentId: string) { const experiment = this.experiments .getCombinedList() @@ -501,6 +505,10 @@ export class Experiments extends BaseRepository { return this.columns.getFirstThreeColumnOrder() } + public getColumnTerminalNodes() { + return this.columns.getTerminalNodes() + } + public getHasData() { if (this.deferred.state === 'none') { return diff --git a/extension/src/persistence/constants.ts b/extension/src/persistence/constants.ts index 96fe84accf..7ba16c5f87 100644 --- a/extension/src/persistence/constants.ts +++ b/extension/src/persistence/constants.ts @@ -10,6 +10,7 @@ export enum PersistenceKey { PLOT_COMPARISON_ORDER = 'plotComparisonOrder:', PLOT_COMPARISON_PATHS_ORDER = 'plotComparisonPathsOrder', PLOT_METRIC_ORDER = 'plotMetricOrder:', + PLOTS_CUSTOM_ORDER = 'plotCustomOrder:', PLOT_SECTION_COLLAPSED = 'plotSectionCollapsed:', PLOT_SELECTED_METRICS = 'plotSelectedMetrics:', PLOT_SIZES = 'plotSizes:', diff --git a/extension/src/plots/model/collect.test.ts b/extension/src/plots/model/collect.test.ts index 624de4ab05..0b81f32648 100644 --- a/extension/src/plots/model/collect.test.ts +++ b/extension/src/plots/model/collect.test.ts @@ -6,12 +6,14 @@ import { collectCheckpointPlotsData, collectTemplates, collectMetricOrder, - collectOverrideRevisionDetails + collectOverrideRevisionDetails, + collectCustomPlotsData } from './collect' import plotsDiffFixture from '../../test/fixtures/plotsDiff/output' import expShowFixture from '../../test/fixtures/expShow/base/output' import modifiedFixture from '../../test/fixtures/expShow/modified/output' import checkpointPlotsFixture from '../../test/fixtures/expShow/base/checkpointPlots' +import customPlotsFixture from '../../test/fixtures/expShow/base/customPlots' import { ExperimentsOutput, ExperimentStatus, @@ -27,6 +29,62 @@ const logsLossPath = join('logs', 'loss.tsv') const logsLossPlot = (plotsDiffFixture[logsLossPath][0] || {}) as TemplatePlot +describe('collectCustomPlotsData', () => { + it('should return the expected data from the text fixture', () => { + const data = collectCustomPlotsData( + [ + { + metric: 'metrics:summary.json:loss', + param: 'params:params.yaml:dropout' + }, + { + metric: 'metrics:summary.json:accuracy', + param: 'params:params.yaml:epochs' + } + ], + [ + { + id: '12345', + label: '123', + metrics: { + 'summary.json': { + accuracy: 0.4668000042438507, + loss: 2.0205044746398926 + } + }, + name: 'exp-e7a67', + params: { 'params.yaml': { dropout: 0.15, epochs: 16 } } + }, + { + id: '12345', + label: '123', + metrics: { + 'summary.json': { + accuracy: 0.3484833240509033, + loss: 1.9293040037155151 + } + }, + name: 'exp-83425', + params: { 'params.yaml': { dropout: 0.25, epochs: 10 } } + }, + { + id: '12345', + label: '123', + metrics: { + 'summary.json': { + accuracy: 0.6768440509033, + loss: 2.298503875732422 + } + }, + name: 'exp-f13bca', + params: { 'params.yaml': { dropout: 0.32, epochs: 20 } } + } + ] + ) + expect(data).toStrictEqual(customPlotsFixture.plots) + }) +}) + describe('collectCheckpointPlotsData', () => { it('should return the expected data from the test fixture', () => { const data = collectCheckpointPlotsData(expShowFixture) diff --git a/extension/src/plots/model/collect.ts b/extension/src/plots/model/collect.ts index f0c0739bfd..8153ebedc8 100644 --- a/extension/src/plots/model/collect.ts +++ b/extension/src/plots/model/collect.ts @@ -1,6 +1,8 @@ import omit from 'lodash.omit' +import get from 'lodash.get' import { TopLevelSpec } from 'vega-lite' import { VisualizationSpec } from 'react-vega' +import { CustomPlotsOrderValue } from '.' import { getRevisionFirstThreeColumns } from './util' import { ColorScale, @@ -13,7 +15,8 @@ import { TemplatePlotEntry, TemplatePlotSection, PlotsType, - Revision + Revision, + CustomPlotData } from '../webview/contract' import { EXPERIMENT_WORKSPACE_ID, @@ -28,9 +31,11 @@ import { import { extractColumns } from '../../experiments/columns/extract' import { decodeColumn, - appendColumnToPath + appendColumnToPath, + splitColumnPath } from '../../experiments/columns/paths' import { + ColumnType, Experiment, isRunning, MetricOrParamColumns @@ -243,6 +248,48 @@ export const collectCheckpointPlotsData = ( return plotsData } +export const getCustomPlotId = (metric: string, param: string) => + `custom-${metric}-${param}` + +const collectCustomPlotData = ( + metric: string, + param: string, + experiments: Experiment[] +): CustomPlotData => { + const splitUpMetricPath = splitColumnPath(metric) + const splitUpParamPath = splitColumnPath(param) + const plotData: CustomPlotData = { + id: getCustomPlotId(metric, param), + metric: metric.slice(ColumnType.METRICS.length + 1), + param: param.slice(ColumnType.PARAMS.length + 1), + values: [] + } + + for (const experiment of experiments) { + const metricValue = get(experiment, splitUpMetricPath) as number | undefined + const paramValue = get(experiment, splitUpParamPath) as number | undefined + + if (metricValue !== undefined && paramValue !== undefined) { + plotData.values.push({ + expName: experiment.name || experiment.label, + metric: metricValue, + param: paramValue + }) + } + } + + return plotData +} + +export const collectCustomPlotsData = ( + metricsAndParams: CustomPlotsOrderValue[], + experiments: Experiment[] +): CustomPlotData[] => { + return metricsAndParams.map(({ metric, param }) => + collectCustomPlotData(metric, param, experiments) + ) +} + type MetricOrderAccumulator = { newOrder: string[] uncollectedMetrics: string[] diff --git a/extension/src/plots/model/index.test.ts b/extension/src/plots/model/index.test.ts index 186bdb6bdb..85833fde07 100644 --- a/extension/src/plots/model/index.test.ts +++ b/extension/src/plots/model/index.test.ts @@ -103,6 +103,7 @@ describe('plotsModel', () => { const expectedSectionCollapsed = { [Section.CHECKPOINT_PLOTS]: true, [Section.TEMPLATE_PLOTS]: false, + [Section.CUSTOM_PLOTS]: false, [Section.COMPARISON_TABLE]: false } diff --git a/extension/src/plots/model/index.ts b/extension/src/plots/model/index.ts index 1cdcf8836e..4ded2e8d1e 100644 --- a/extension/src/plots/model/index.ts +++ b/extension/src/plots/model/index.ts @@ -11,7 +11,9 @@ import { RevisionData, TemplateAccumulator, collectCommitRevisionDetails, - collectOverrideRevisionDetails + collectOverrideRevisionDetails, + collectCustomPlotsData, + getCustomPlotId } from './collect' import { getRevisionFirstThreeColumns } from './util' import { @@ -24,7 +26,8 @@ import { DEFAULT_SECTION_SIZES, Section, SectionCollapsed, - PlotSizeNumber + PlotSizeNumber, + CustomPlotData } from '../webview/contract' import { ExperimentsOutput, @@ -46,10 +49,13 @@ import { } from '../multiSource/collect' import { isDvcError } from '../../cli/dvc/reader' +export type CustomPlotsOrderValue = { metric: string; param: string } + export class PlotsModel extends ModelWithPersistence { private readonly experiments: Experiments private plotSizes: Record + private customPlotsOrder: CustomPlotsOrderValue[] private sectionCollapsed: SectionCollapsed private commitRevisions: Record = {} @@ -64,6 +70,7 @@ export class PlotsModel extends ModelWithPersistence { private multiSourceEncoding: MultiSourceEncoding = {} private checkpointPlots?: CheckpointPlot[] + private customPlots?: CustomPlotData[] private selectedMetrics?: string[] private metricOrder: string[] @@ -89,6 +96,8 @@ export class PlotsModel extends ModelWithPersistence { undefined ) this.metricOrder = this.revive(PersistenceKey.PLOT_METRIC_ORDER, []) + + this.customPlotsOrder = this.revive(PersistenceKey.PLOTS_CUSTOM_ORDER, []) } public transformAndSetExperiments(data: ExperimentsOutput) { @@ -102,6 +111,8 @@ export class PlotsModel extends ModelWithPersistence { this.setMetricOrder() + this.recreateCustomPlots() + return this.removeStaleData() } @@ -119,6 +130,8 @@ export class PlotsModel extends ModelWithPersistence { collectMultiSourceVariations(data, this.multiSourceVariations) ]) + this.recreateCustomPlots() + this.comparisonData = { ...this.comparisonData, ...comparisonData @@ -127,7 +140,6 @@ export class PlotsModel extends ModelWithPersistence { ...this.revisionData, ...revisionData } - this.templates = { ...this.templates, ...templates } this.multiSourceVariations = multiSourceVariations this.multiSourceEncoding = collectMultiSourceEncoding( @@ -171,6 +183,49 @@ export class PlotsModel extends ModelWithPersistence { } } + public getCustomPlots() { + if (!this.customPlots) { + return + } + return { + plots: this.customPlots, + size: this.getPlotSize(Section.CUSTOM_PLOTS) + } + } + + public recreateCustomPlots() { + const customPlots: CustomPlotData[] = collectCustomPlotsData( + this.getCustomPlotsOrder(), + this.experiments.getExperiments() + ) + this.customPlots = customPlots + } + + public getCustomPlotsOrder() { + return this.customPlotsOrder + } + + public setCustomPlotsOrder(plotsOrder: CustomPlotsOrderValue[]) { + this.customPlotsOrder = plotsOrder + this.persist(PersistenceKey.PLOTS_CUSTOM_ORDER, this.customPlotsOrder) + this.recreateCustomPlots() + } + + public removeCustomPlots(plotIds: string[]) { + const newCustomPlotsOrder = this.getCustomPlotsOrder().filter( + ({ metric, param }) => { + return !plotIds.includes(getCustomPlotId(metric, param)) + } + ) + + this.setCustomPlotsOrder(newCustomPlotsOrder) + } + + public addCustomPlot(metricAndParam: CustomPlotsOrderValue) { + const newCustomPlotsOrder = [...this.getCustomPlotsOrder(), metricAndParam] + this.setCustomPlotsOrder(newCustomPlotsOrder) + } + public setupManualRefresh(id: string) { this.deleteRevisionData(id) } diff --git a/extension/src/plots/model/quickPick.test.ts b/extension/src/plots/model/quickPick.test.ts new file mode 100644 index 0000000000..363b3503a8 --- /dev/null +++ b/extension/src/plots/model/quickPick.test.ts @@ -0,0 +1,128 @@ +import { CustomPlotsOrderValue } from '.' +import { pickCustomPlots, pickMetricAndParam } from './quickPick' +import { quickPickManyValues, quickPickValue } from '../../vscode/quickPick' +import { Title } from '../../vscode/title' +import { Toast } from '../../vscode/toast' +import { ColumnType } from '../../experiments/webview/contract' + +jest.mock('../../vscode/quickPick') +jest.mock('../../vscode/toast') + +const mockedQuickPickValue = jest.mocked(quickPickValue) +const mockedQuickPickManyValues = jest.mocked(quickPickManyValues) + +const mockedToast = jest.mocked(Toast) +const mockedShowError = jest.fn() +mockedToast.showError = mockedShowError + +beforeEach(() => { + jest.resetAllMocks() +}) + +describe('pickCustomPlots', () => { + it('should return early given no plots', async () => { + const undef = await pickCustomPlots([], 'There are no plots to select.', {}) + expect(undef).toBeUndefined() + expect(mockedQuickPickManyValues).not.toHaveBeenCalled() + expect(mockedShowError).toHaveBeenCalledTimes(1) + }) + + it('should return the selected plots', async () => { + const selectedPlots = [ + 'custom-metrics:summary.json:loss-params:params.yaml:dropout', + 'custom-metrics:summary.json:accuracy-params:params.yaml:epochs' + ] + const mockedExperiments = [ + { + metric: 'metrics:summary.json:loss', + param: 'params:params.yaml:dropout' + }, + { + metric: 'metrics:summary.json:accuracy', + param: 'params:params.yaml:epochs' + }, + { + metric: 'metrics:summary.json:learning_rate', + param: 'param:summary.json:process.threshold' + } + ] as CustomPlotsOrderValue[] + + mockedQuickPickManyValues.mockResolvedValueOnce(selectedPlots) + const picked = await pickCustomPlots( + mockedExperiments, + 'There are no plots to remove.', + { title: Title.SELECT_CUSTOM_PLOTS_TO_REMOVE } + ) + + expect(picked).toStrictEqual(selectedPlots) + expect(mockedQuickPickManyValues).toHaveBeenCalledTimes(1) + expect(mockedQuickPickManyValues).toHaveBeenCalledWith( + [ + { + description: + 'metrics:summary.json:loss vs params:params.yaml:dropout', + label: 'loss vs dropout', + value: 'custom-metrics:summary.json:loss-params:params.yaml:dropout' + }, + { + description: + 'metrics:summary.json:accuracy vs params:params.yaml:epochs', + label: 'accuracy vs epochs', + value: + 'custom-metrics:summary.json:accuracy-params:params.yaml:epochs' + }, + { + description: + 'metrics:summary.json:learning_rate vs param:summary.json:process.threshold', + label: 'learning_rate vs threshold', + value: + 'custom-metrics:summary.json:learning_rate-param:summary.json:process.threshold' + } + ], + { title: Title.SELECT_CUSTOM_PLOTS_TO_REMOVE } + ) + }) +}) + +describe('pickMetricAndParam', () => { + it('should end early if there are no metrics or params available', async () => { + mockedQuickPickValue.mockResolvedValueOnce(undefined) + const undef = await pickMetricAndParam([]) + expect(undef).toBeUndefined() + expect(mockedShowError).toHaveBeenCalledTimes(1) + }) + + it('should return a metric and a param if both are selected by the user', async () => { + const expectedMetric = { + label: 'loss', + path: 'metrics:summary.json:loss' + } + const expectedParam = { + label: 'epochs', + path: 'summary.json:loss-params:params.yaml:epochs' + } + mockedQuickPickValue + .mockResolvedValueOnce(expectedMetric) + .mockResolvedValueOnce(expectedParam) + const metricAndParam = await pickMetricAndParam([ + { ...expectedMetric, hasChildren: false, type: ColumnType.METRICS }, + { ...expectedParam, hasChildren: false, type: ColumnType.PARAMS }, + { + hasChildren: false, + label: 'dropout', + path: 'params:params.yaml:dropout', + type: ColumnType.PARAMS + }, + { + hasChildren: false, + label: 'accuracy', + path: 'summary.json:accuracy', + type: ColumnType.METRICS + } + ]) + expect(metricAndParam).toStrictEqual({ + metric: expectedMetric.path, + param: expectedParam.path + }) + }) +}) diff --git a/extension/src/plots/model/quickPick.ts b/extension/src/plots/model/quickPick.ts new file mode 100644 index 0000000000..eee9723a3e --- /dev/null +++ b/extension/src/plots/model/quickPick.ts @@ -0,0 +1,70 @@ +import { CustomPlotsOrderValue } from '.' +import { getCustomPlotId } from './collect' +import { splitColumnPath } from '../../experiments/columns/paths' +import { pickFromColumnLikes } from '../../experiments/columns/quickPick' +import { Column, ColumnType } from '../../experiments/webview/contract' +import { definedAndNonEmpty } from '../../util/array' +import { + quickPickManyValues, + QuickPickOptionsWithTitle +} from '../../vscode/quickPick' +import { Title } from '../../vscode/title' +import { Toast } from '../../vscode/toast' + +export const pickCustomPlots = ( + plots: CustomPlotsOrderValue[], + noPlotsErrorMessage: string, + quickPickOptions: QuickPickOptionsWithTitle +): Thenable => { + if (!definedAndNonEmpty(plots)) { + return Toast.showError(noPlotsErrorMessage) + } + + const plotsItems = plots.map(({ metric, param }) => { + const splitMetric = splitColumnPath(metric) + const splitParam = splitColumnPath(param) + return { + description: `${metric} vs ${param}`, + label: `${splitMetric[splitMetric.length - 1]} vs ${ + splitParam[splitParam.length - 1] + }`, + value: getCustomPlotId(metric, param) + } + }) + + return quickPickManyValues(plotsItems, quickPickOptions) +} + +const getTypeColumnLikes = (columns: Column[], columnType: ColumnType) => + columns + .filter(({ type }) => type === columnType) + .map(({ label, path }) => ({ label, path })) + +export const pickMetricAndParam = async (columns: Column[]) => { + const metricColumnLikes = getTypeColumnLikes(columns, ColumnType.METRICS) + const paramColumnLikes = getTypeColumnLikes(columns, ColumnType.PARAMS) + + if ( + !definedAndNonEmpty(metricColumnLikes) || + !definedAndNonEmpty(paramColumnLikes) + ) { + return Toast.showError('There are no metrics or params to select from.') + } + + const metric = await pickFromColumnLikes(metricColumnLikes, { + title: Title.SELECT_METRIC_CUSTOM_PLOT + }) + + if (!metric) { + return + } + + const param = await pickFromColumnLikes(paramColumnLikes, { + title: Title.SELECT_PARAM_CUSTOM_PLOT + }) + + if (!param) { + return + } + return { metric: metric.path, param: param.path } +} diff --git a/extension/src/plots/webview/contract.ts b/extension/src/plots/webview/contract.ts index bbd44b798c..6ef496c4f6 100644 --- a/extension/src/plots/webview/contract.ts +++ b/extension/src/plots/webview/contract.ts @@ -11,19 +11,22 @@ export const PlotSizeNumber = { export enum Section { CHECKPOINT_PLOTS = 'checkpoint-plots', TEMPLATE_PLOTS = 'template-plots', - COMPARISON_TABLE = 'comparison-table' + COMPARISON_TABLE = 'comparison-table', + CUSTOM_PLOTS = 'custom-plots' } export const DEFAULT_SECTION_SIZES = { [Section.CHECKPOINT_PLOTS]: PlotSizeNumber.REGULAR, [Section.TEMPLATE_PLOTS]: PlotSizeNumber.REGULAR, - [Section.COMPARISON_TABLE]: PlotSizeNumber.REGULAR + [Section.COMPARISON_TABLE]: PlotSizeNumber.REGULAR, + [Section.CUSTOM_PLOTS]: PlotSizeNumber.REGULAR } export const DEFAULT_SECTION_COLLAPSED = { [Section.CHECKPOINT_PLOTS]: false, [Section.TEMPLATE_PLOTS]: false, - [Section.COMPARISON_TABLE]: false + [Section.COMPARISON_TABLE]: false, + [Section.CUSTOM_PLOTS]: false } export type SectionCollapsed = typeof DEFAULT_SECTION_COLLAPSED @@ -70,6 +73,24 @@ export type CheckpointPlot = { values: CheckpointPlotValues } +export type CustomPlotValues = { + expName: string + metric: number + param: number +} + +export type CustomPlotData = { + id: string + values: CustomPlotValues[] + metric: string + param: string +} + +export type CustomPlotsData = { + plots: CustomPlotData[] + size: number +} + export type CheckpointPlotData = CheckpointPlot & { title: string } export type CheckpointPlotsData = { @@ -134,6 +155,7 @@ export type ComparisonPlot = { export enum PlotsDataKeys { COMPARISON = 'comparison', CHECKPOINT = 'checkpoint', + CUSTOM = 'custom', HAS_UNSELECTED_PLOTS = 'hasUnselectedPlots', HAS_PLOTS = 'hasPlots', SELECTED_REVISIONS = 'selectedRevisions', @@ -145,6 +167,7 @@ export type PlotsData = | { [PlotsDataKeys.COMPARISON]?: PlotsComparisonData | null [PlotsDataKeys.CHECKPOINT]?: CheckpointPlotsData | null + [PlotsDataKeys.CUSTOM]?: CustomPlotsData | null [PlotsDataKeys.HAS_PLOTS]?: boolean [PlotsDataKeys.HAS_UNSELECTED_PLOTS]?: boolean [PlotsDataKeys.SELECTED_REVISIONS]?: Revision[] diff --git a/extension/src/plots/webview/messages.ts b/extension/src/plots/webview/messages.ts index e60fc9e941..8d75e6aa1a 100644 --- a/extension/src/plots/webview/messages.ts +++ b/extension/src/plots/webview/messages.ts @@ -21,6 +21,11 @@ import { PlotsModel } from '../model' import { PathsModel } from '../paths/model' import { BaseWebview } from '../../webview' import { getModifiedTime } from '../../fileSystem' +import { pickCustomPlots, pickMetricAndParam } from '../model/quickPick' +import { Title } from '../../vscode/title' +import { ColumnType } from '../../experiments/webview/contract' +import { FILE_SEPARATOR } from '../../experiments/columns/paths' +import { reorderObjectList } from '../../util/array' export class WebviewMessages { private readonly paths: PathsModel @@ -54,6 +59,7 @@ export class WebviewMessages { void this.getWebview()?.show({ checkpoint: this.getCheckpointPlots(), comparison: this.getComparisonPlots(overrideComparison), + custom: this.getCustomPlots(), hasPlots: !!this.paths.hasPaths(), hasUnselectedPlots: this.paths.getHasUnselectedPlots(), sectionCollapsed: this.plots.getSectionCollapsed(), @@ -70,6 +76,8 @@ export class WebviewMessages { public handleMessageFromWebview(message: MessageFromWebview) { switch (message.type) { + case MessageFromWebviewType.ADD_CUSTOM_PLOT: + return this.addCustomPlot() case MessageFromWebviewType.TOGGLE_METRIC: return this.setSelectedMetrics(message.payload) case MessageFromWebviewType.RESIZE_PLOTS: @@ -84,10 +92,14 @@ export class WebviewMessages { return this.setTemplateOrder(message.payload) case MessageFromWebviewType.REORDER_PLOTS_METRICS: return this.setMetricOrder(message.payload) + case MessageFromWebviewType.REORDER_PLOTS_CUSTOM: + return this.setCustomPlotsOrder(message.payload) case MessageFromWebviewType.SELECT_PLOTS: return this.selectPlotsFromWebview() case MessageFromWebviewType.SELECT_EXPERIMENTS: return this.selectExperimentsFromWebview() + case MessageFromWebviewType.REMOVE_CUSTOM_PLOTS: + return this.removeCustomPlots() case MessageFromWebviewType.REFRESH_REVISION: return this.attemptToRefreshRevData(message.payload) case MessageFromWebviewType.REFRESH_REVISIONS: @@ -158,6 +170,80 @@ export class WebviewMessages { this.sendCheckpointPlotsAndEvent(EventName.VIEWS_REORDER_PLOTS_METRICS) } + private async addCustomPlot() { + const metricAndParam = await pickMetricAndParam( + this.experiments.getColumnTerminalNodes() + ) + + if (!metricAndParam) { + return + } + + const plotAlreadyExists = this.plots + .getCustomPlotsOrder() + .some( + ({ param, metric }) => + param === metricAndParam.param && metric === metricAndParam.metric + ) + + if (plotAlreadyExists) { + return Toast.showError('Custom plot already exists.') + } + + this.plots.addCustomPlot(metricAndParam) + this.sendCustomPlots() + sendTelemetryEvent( + EventName.VIEWS_PLOTS_CUSTOM_PLOT_ADDED, + undefined, + undefined + ) + } + + private async removeCustomPlots() { + const selectedPlotsIds = await pickCustomPlots( + this.plots.getCustomPlotsOrder(), + 'There are no plots to remove.', + { + title: Title.SELECT_CUSTOM_PLOTS_TO_REMOVE + } + ) + + if (!selectedPlotsIds) { + return + } + + this.plots.removeCustomPlots(selectedPlotsIds) + this.sendCustomPlots() + sendTelemetryEvent( + EventName.VIEWS_PLOTS_CUSTOM_PLOT_REMOVED, + undefined, + undefined + ) + } + + private setCustomPlotsOrder(plotIds: string[]) { + const customPlots = this.plots.getCustomPlots()?.plots + if (!customPlots) { + return + } + + const buildMetricOrParamPath = (type: string, path: string) => + type + FILE_SEPARATOR + path + const newOrder = reorderObjectList(plotIds, customPlots, 'id').map( + ({ metric, param }) => ({ + metric: buildMetricOrParamPath(ColumnType.METRICS, metric), + param: buildMetricOrParamPath(ColumnType.PARAMS, param) + }) + ) + this.plots.setCustomPlotsOrder(newOrder) + this.sendCustomPlots() + sendTelemetryEvent( + EventName.VIEWS_REORDER_PLOTS_CUSTOM, + undefined, + undefined + ) + } + private selectPlotsFromWebview() { void this.selectPlots() sendTelemetryEvent(EventName.VIEWS_PLOTS_SELECT_PLOTS, undefined, undefined) @@ -234,6 +320,12 @@ export class WebviewMessages { }) } + private sendCustomPlots() { + void this.getWebview()?.show({ + custom: this.getCustomPlots() + }) + } + private getTemplatePlots(overrideRevs?: Revision[]) { const paths = this.paths.getTemplateOrder() const plots = this.plots.getTemplatePlots(paths, overrideRevs) @@ -295,4 +387,8 @@ export class WebviewMessages { private getCheckpointPlots() { return this.plots.getCheckpointPlots() || null } + + private getCustomPlots() { + return this.plots.getCustomPlots() || null + } } diff --git a/extension/src/telemetry/constants.ts b/extension/src/telemetry/constants.ts index e7c28a8f75..5013e35731 100644 --- a/extension/src/telemetry/constants.ts +++ b/extension/src/telemetry/constants.ts @@ -63,6 +63,8 @@ export const EventName = Object.assign( VIEWS_PLOTS_COMPARISON_ROWS_REORDERED: 'views.plots.comparisonRowsReordered', VIEWS_PLOTS_CREATED: 'views.plots.created', + VIEWS_PLOTS_CUSTOM_PLOT_ADDED: 'views.plots.addCustomPlot', + VIEWS_PLOTS_CUSTOM_PLOT_REMOVED: 'views.plots.removeCustomPlot', VIEWS_PLOTS_EXPERIMENT_TOGGLE: 'views.plots.toggleExperimentStatus', VIEWS_PLOTS_FOCUS_CHANGED: 'views.plots.focusChanged', VIEWS_PLOTS_MANUAL_REFRESH: 'views.plots.manualRefresh', @@ -72,6 +74,7 @@ export const EventName = Object.assign( VIEWS_PLOTS_SECTION_TOGGLE: 'views.plots.toggleSection', VIEWS_PLOTS_SELECT_EXPERIMENTS: 'view.plots.selectExperiments', VIEWS_PLOTS_SELECT_PLOTS: 'view.plots.selectPlots', + VIEWS_REORDER_PLOTS_CUSTOM: 'views.plots.customReordered', VIEWS_REORDER_PLOTS_METRICS: 'views.plots.metricsReordered', VIEWS_REORDER_PLOTS_TEMPLATES: 'views.plots.templatesReordered', @@ -242,6 +245,8 @@ export interface IEventNamePropertyMapping { [EventName.VIEWS_PLOTS_CLOSED]: undefined [EventName.VIEWS_PLOTS_CREATED]: undefined + [EventName.VIEWS_PLOTS_CUSTOM_PLOT_REMOVED]: undefined + [EventName.VIEWS_PLOTS_CUSTOM_PLOT_ADDED]: undefined [EventName.VIEWS_PLOTS_FOCUS_CHANGED]: WebviewFocusChangedProperties [EventName.VIEWS_PLOTS_MANUAL_REFRESH]: { revisions: number } [EventName.VIEWS_PLOTS_METRICS_SELECTED]: undefined @@ -253,6 +258,7 @@ export interface IEventNamePropertyMapping { [EventName.VIEWS_PLOTS_SELECT_PLOTS]: undefined [EventName.VIEWS_PLOTS_EXPERIMENT_TOGGLE]: undefined [EventName.VIEWS_REORDER_PLOTS_METRICS]: undefined + [EventName.VIEWS_REORDER_PLOTS_CUSTOM]: undefined [EventName.VIEWS_REORDER_PLOTS_TEMPLATES]: undefined [EventName.VIEWS_PLOTS_PATH_TREE_OPENED]: DvcRootCount diff --git a/extension/src/test/fixtures/expShow/base/customPlots.ts b/extension/src/test/fixtures/expShow/base/customPlots.ts new file mode 100644 index 0000000000..f2be8b5992 --- /dev/null +++ b/extension/src/test/fixtures/expShow/base/customPlots.ts @@ -0,0 +1,56 @@ +import { + CustomPlotsData, + PlotSizeNumber +} from '../../../../plots/webview/contract' + +const data: CustomPlotsData = { + plots: [ + { + id: 'custom-metrics:summary.json:loss-params:params.yaml:dropout', + metric: 'summary.json:loss', + param: 'params.yaml:dropout', + values: [ + { + expName: 'exp-e7a67', + metric: 2.0205044746398926, + param: 0.15 + }, + { + expName: 'exp-83425', + metric: 1.9293040037155151, + param: 0.25 + }, + { + expName: 'exp-f13bca', + metric: 2.298503875732422, + param: 0.32 + } + ] + }, + { + id: 'custom-metrics:summary.json:accuracy-params:params.yaml:epochs', + metric: 'summary.json:accuracy', + param: 'params.yaml:epochs', + values: [ + { + expName: 'exp-e7a67', + metric: 0.4668000042438507, + param: 16 + }, + { + expName: 'exp-83425', + metric: 0.3484833240509033, + param: 10 + }, + { + expName: 'exp-f13bca', + metric: 0.6768440509033, + param: 20 + } + ] + } + ], + size: PlotSizeNumber.REGULAR +} + +export default data diff --git a/extension/src/test/suite/plots/index.test.ts b/extension/src/test/suite/plots/index.test.ts index 29d36b7901..9a3f86363c 100644 --- a/extension/src/test/suite/plots/index.test.ts +++ b/extension/src/test/suite/plots/index.test.ts @@ -8,6 +8,7 @@ import { buildPlots } from '../plots/util' import { Disposable } from '../../../extension' import expShowFixtureWithoutErrors from '../../fixtures/expShow/base/noErrors' import checkpointPlotsFixture from '../../fixtures/expShow/base/checkpointPlots' +import customPlotsFixture from '../../fixtures/expShow/base/customPlots' import plotsDiffFixture from '../../fixtures/plotsDiff/output' import multiSourcePlotsDiffFixture from '../../fixtures/plotsDiff/multiSource' import templatePlotsFixture from '../../fixtures/plotsDiff/template' @@ -40,6 +41,7 @@ import { EXPERIMENT_WORKSPACE_ID } from '../../../cli/dvc/contract' import { SelectedExperimentWithColor } from '../../../experiments/model' +import * as customPlotQuickPickUtil from '../../../plots/model/quickPick' suite('Plots Test Suite', () => { const disposable = Disposable.fn() @@ -508,6 +510,55 @@ suite('Plots Test Suite', () => { ) }).timeout(WEBVIEW_TEST_TIMEOUT) + it('should handle a custom plots reordered message from the webview', async () => { + const { plots, plotsModel, messageSpy } = await buildPlots( + disposable, + plotsDiffFixture + ) + + const webview = await plots.showWebview() + + const mockNewCustomPlotsOrder = [ + 'custom-metrics:summary.json:accuracy-params:params.yaml:epochs', + 'custom-metrics:summary.json:loss-params:params.yaml:dropout' + ] + + stub(plotsModel, 'getCustomPlots') + .onFirstCall() + .returns(customPlotsFixture) + + const mockSendTelemetryEvent = stub(Telemetry, 'sendTelemetryEvent') + const mockMessageReceived = getMessageReceivedEmitter(webview) + const mockSetCustomPlotsOrder = stub(plotsModel, 'setCustomPlotsOrder') + mockSetCustomPlotsOrder.returns(undefined) + + messageSpy.resetHistory() + + mockMessageReceived.fire({ + payload: mockNewCustomPlotsOrder, + type: MessageFromWebviewType.REORDER_PLOTS_CUSTOM + }) + + expect(mockSetCustomPlotsOrder).to.be.calledOnce + expect(mockSetCustomPlotsOrder).to.be.calledWithExactly([ + { + metric: 'metrics:summary.json:accuracy', + param: 'params:params.yaml:epochs' + }, + { + metric: 'metrics:summary.json:loss', + param: 'params:params.yaml:dropout' + } + ]) + expect(messageSpy).to.be.calledOnce + expect(mockSendTelemetryEvent).to.be.calledOnce + expect(mockSendTelemetryEvent).to.be.calledWithExactly( + EventName.VIEWS_REORDER_PLOTS_CUSTOM, + undefined, + undefined + ) + }).timeout(WEBVIEW_TEST_TIMEOUT) + it('should handle a select experiments message from the webview', async () => { const { plots, experiments } = await buildPlots( disposable, @@ -674,6 +725,7 @@ suite('Plots Test Suite', () => { const expectedPlotsData: TPlotsData = { checkpoint: checkpointPlotsFixture, comparison: comparisonPlotsFixture, + custom: { plots: [], size: 2 }, hasPlots: true, hasUnselectedPlots: false, sectionCollapsed: DEFAULT_SECTION_COLLAPSED, @@ -798,5 +850,98 @@ suite('Plots Test Suite', () => { undefined ) }).timeout(WEBVIEW_TEST_TIMEOUT) + + it('should handle a add custom plot message from the webview', async () => { + const { plots, plotsModel } = await buildPlots( + disposable, + plotsDiffFixture + ) + + const webview = await plots.showWebview() + + const mockGetMetricAndParam = stub( + customPlotQuickPickUtil, + 'pickMetricAndParam' + ) + + const quickPickEvent = new Promise(resolve => + mockGetMetricAndParam.callsFake(() => { + resolve(undefined) + return Promise.resolve({ + metric: 'metrics:summary.json:loss', + param: 'params:params.yaml:dropout' + }) + }) + ) + + const mockSetCustomPlotsOrder = stub(plotsModel, 'setCustomPlotsOrder') + mockSetCustomPlotsOrder.returns(undefined) + + const mockSendTelemetryEvent = stub(Telemetry, 'sendTelemetryEvent') + const mockMessageReceived = getMessageReceivedEmitter(webview) + + mockMessageReceived.fire({ type: MessageFromWebviewType.ADD_CUSTOM_PLOT }) + + await quickPickEvent + + expect(mockSetCustomPlotsOrder).to.be.calledWith([ + { + metric: 'metrics:summary.json:loss', + param: 'params:params.yaml:dropout' + } + ]) + expect(mockSendTelemetryEvent).to.be.calledWith( + EventName.VIEWS_PLOTS_CUSTOM_PLOT_ADDED, + undefined + ) + }) + + it('should handle a remove custom plot message from the webview', async () => { + const { plots, plotsModel } = await buildPlots( + disposable, + plotsDiffFixture + ) + + const webview = await plots.showWebview() + + const mockSelectCustomPlots = stub( + customPlotQuickPickUtil, + 'pickCustomPlots' + ) + + const quickPickEvent = new Promise(resolve => + mockSelectCustomPlots.callsFake(() => { + resolve(undefined) + return Promise.resolve([ + 'custom-metrics:summary.json:loss-params:params.yaml:dropout' + ]) + }) + ) + + stub(plotsModel, 'getCustomPlotsOrder').returns([ + { + metric: 'metrics:summary.json:loss', + param: 'params:params.yaml:dropout' + } + ]) + + const mockSetCustomPlotsOrder = stub(plotsModel, 'setCustomPlotsOrder') + mockSetCustomPlotsOrder.returns(undefined) + + const mockSendTelemetryEvent = stub(Telemetry, 'sendTelemetryEvent') + const mockMessageReceived = getMessageReceivedEmitter(webview) + + mockMessageReceived.fire({ + type: MessageFromWebviewType.REMOVE_CUSTOM_PLOTS + }) + + await quickPickEvent + + expect(mockSetCustomPlotsOrder).to.be.calledWith([]) + expect(mockSendTelemetryEvent).to.be.calledWith( + EventName.VIEWS_PLOTS_CUSTOM_PLOT_REMOVED, + undefined + ) + }) }) }) diff --git a/extension/src/vscode/title.ts b/extension/src/vscode/title.ts index 179d7f81c4..39be6d0c11 100644 --- a/extension/src/vscode/title.ts +++ b/extension/src/vscode/title.ts @@ -23,6 +23,9 @@ export enum Title { SELECT_OPERATOR = 'Select an Operator', SELECT_PARAM_OR_METRIC_FILTER = 'Select a Param or Metric to Filter by', SELECT_PARAM_OR_METRIC_SORT = 'Select a Param or Metric to Sort by', + SELECT_METRIC_CUSTOM_PLOT = 'Select a Metric to Create a Custom Plot', + SELECT_PARAM_CUSTOM_PLOT = 'Select a Param to Create a Custom Plot', + SELECT_CUSTOM_PLOTS_TO_REMOVE = 'Select Custom Plot(s) to Remove', SELECT_PARAM_TO_MODIFY = 'Select Param(s) to Modify', SELECT_PLOTS = 'Select Plots to Display', SELECT_QUEUE_KILL = 'Select Queue Task(s) to Kill', diff --git a/extension/src/webview/contract.ts b/extension/src/webview/contract.ts index 00d071112b..c0127bb093 100644 --- a/extension/src/webview/contract.ts +++ b/extension/src/webview/contract.ts @@ -15,6 +15,7 @@ export enum MessageFromWebviewType { ADD_CONFIGURATION = 'add-configuration', APPLY_EXPERIMENT_TO_WORKSPACE = 'apply-experiment-to-workspace', ADD_STARRED_EXPERIMENT_FILTER = 'add-starred-experiment-filter', + ADD_CUSTOM_PLOT = 'add-custom-plot', CREATE_BRANCH_FROM_EXPERIMENT = 'create-branch-from-experiment', FOCUS_FILTERS_TREE = 'focus-filters-tree', FOCUS_SORTS_TREE = 'focus-sorts-tree', @@ -28,6 +29,7 @@ export enum MessageFromWebviewType { REORDER_PLOTS_COMPARISON = 'reorder-plots-comparison', REORDER_PLOTS_COMPARISON_ROWS = 'reorder-plots-comparison-rows', REORDER_PLOTS_METRICS = 'reorder-plots-metrics', + REORDER_PLOTS_CUSTOM = 'reorder-plots-custom', REORDER_PLOTS_TEMPLATES = 'reorder-plots-templates', REFRESH_REVISION = 'refresh-revision', REFRESH_REVISIONS = 'refresh-revisions', @@ -51,6 +53,7 @@ export enum MessageFromWebviewType { SHARE_EXPERIMENT_AS_COMMIT = 'share-experiment-as-commit', TOGGLE_METRIC = 'toggle-metric', TOGGLE_PLOTS_SECTION = 'toggle-plots-section', + REMOVE_CUSTOM_PLOTS = 'remove-custom-plots', MODIFY_EXPERIMENT_PARAMS_AND_QUEUE = 'modify-experiment-params-and-queue', MODIFY_EXPERIMENT_PARAMS_AND_RUN = 'modify-experiment-params-and-run', MODIFY_EXPERIMENT_PARAMS_RESET_AND_RUN = 'modify-experiment-params-reset-and-run', @@ -78,6 +81,9 @@ export type PlotsTemplatesReordered = { }[] export type MessageFromWebview = + | { + type: MessageFromWebviewType.ADD_CUSTOM_PLOT + } | { type: MessageFromWebviewType.REORDER_COLUMNS payload: string[] @@ -150,6 +156,9 @@ export type MessageFromWebview = type: MessageFromWebviewType.TOGGLE_METRIC payload: string[] } + | { + type: MessageFromWebviewType.REMOVE_CUSTOM_PLOTS + } | { type: MessageFromWebviewType.REORDER_PLOTS_COMPARISON payload: string[] @@ -162,6 +171,10 @@ export type MessageFromWebview = type: MessageFromWebviewType.REORDER_PLOTS_METRICS payload: string[] } + | { + type: MessageFromWebviewType.REORDER_PLOTS_CUSTOM + payload: string[] + } | { type: MessageFromWebviewType.RESIZE_PLOTS payload: PlotsResizedPayload diff --git a/webview/src/plots/components/App.test.tsx b/webview/src/plots/components/App.test.tsx index d0d562c877..973b0ba295 100644 --- a/webview/src/plots/components/App.test.tsx +++ b/webview/src/plots/components/App.test.tsx @@ -14,6 +14,7 @@ import { import '@testing-library/jest-dom/extend-expect' import comparisonTableFixture from 'dvc/src/test/fixtures/plotsDiff/comparison' import checkpointPlotsFixture from 'dvc/src/test/fixtures/expShow/base/checkpointPlots' +import customPlotsFixture from 'dvc/src/test/fixtures/expShow/base/customPlots' import plotsRevisionsFixture from 'dvc/src/test/fixtures/plotsDiff/revisions' import templatePlotsFixture from 'dvc/src/test/fixtures/plotsDiff/template/webview' import smoothTemplatePlotContent from 'dvc/src/test/fixtures/plotsDiff/template/smoothTemplatePlot' @@ -82,6 +83,16 @@ jest.mock('./checkpointPlots/util', () => ({ width: 100 }) })) +jest.mock('./customPlots/util', () => ({ + createSpec: () => ({ + $schema: 'https://vega.github.io/schema/vega-lite/v5.json', + encoding: {}, + height: 100, + layer: [], + transform: [], + width: 100 + }) +})) jest.spyOn(console, 'warn').mockImplementation(() => {}) const { postMessage } = vsCodeApi @@ -101,7 +112,8 @@ describe('App', () => { const sectionPosition = { [Section.CHECKPOINT_PLOTS]: 2, [Section.TEMPLATE_PLOTS]: 0, - [Section.COMPARISON_TABLE]: 1 + [Section.COMPARISON_TABLE]: 1, + [Section.CUSTOM_PLOTS]: 3 } const sendSetDataMessage = (data: PlotsData) => { @@ -350,7 +362,8 @@ describe('App', () => { expect(screen.getByText('Trends')).toBeInTheDocument() expect(screen.getByText('Data Series')).toBeInTheDocument() expect(screen.getByText('Images')).toBeInTheDocument() - expect(screen.getByText('No Plots to Display')).toBeInTheDocument() + expect(screen.getByText('Custom')).toBeInTheDocument() + expect(screen.getAllByText('No Plots to Display')).toHaveLength(2) expect(screen.getByText('No Images to Compare')).toBeInTheDocument() }) @@ -361,7 +374,31 @@ describe('App', () => { expect(screen.queryByText('Loading Plots...')).not.toBeInTheDocument() expect(screen.getByText('Trends')).toBeInTheDocument() - expect(screen.getByText('No Plots to Display')).toBeInTheDocument() + expect(screen.getAllByText('No Plots to Display')).toHaveLength(2) + }) + + it('should render other sections given a message with only custom plots data', () => { + renderAppWithOptionalData({ + custom: customPlotsFixture + }) + + expect(screen.queryByText('Loading Plots...')).not.toBeInTheDocument() + expect(screen.getByText('Trends')).toBeInTheDocument() + expect(screen.getByText('Data Series')).toBeInTheDocument() + expect(screen.getByText('Images')).toBeInTheDocument() + expect(screen.getByText('Custom')).toBeInTheDocument() + expect(screen.getAllByText('No Plots to Display')).toHaveLength(2) + expect(screen.getByText('No Images to Compare')).toBeInTheDocument() + }) + + it('should render custom even when there is no custom plots data', () => { + renderAppWithOptionalData({ + comparison: comparisonTableFixture + }) + + expect(screen.queryByText('Loading Plots...')).not.toBeInTheDocument() + expect(screen.getByText('Custom')).toBeInTheDocument() + expect(screen.getAllByText('No Plots to Display')).toHaveLength(3) }) it('should render the comparison table when given a message with comparison plots data', () => { @@ -403,6 +440,7 @@ describe('App', () => { renderAppWithOptionalData({ checkpoint: checkpointPlotsFixture, comparison: comparisonTableFixture, + custom: customPlotsFixture, template: templatePlotsFixture }) @@ -796,6 +834,58 @@ describe('App', () => { ]) }) + it('should add a custom plot if a user creates a custom plot', () => { + renderAppWithOptionalData({ + custom: { + ...customPlotsFixture, + plots: customPlotsFixture.plots.slice(1) + } + }) + + expect( + screen.getAllByTestId(/summary\.json/).map(plot => plot.id) + ).toStrictEqual([ + 'custom-metrics:summary.json:accuracy-params:params.yaml:epochs' + ]) + + sendSetDataMessage({ + custom: customPlotsFixture + }) + + expect( + screen.getAllByTestId(/summary\.json/).map(plot => plot.id) + ).toStrictEqual([ + 'custom-metrics:summary.json:accuracy-params:params.yaml:epochs', + 'custom-metrics:summary.json:loss-params:params.yaml:dropout' + ]) + }) + + it('should remove a custom plot if a user deletes a custom plot', () => { + renderAppWithOptionalData({ + custom: customPlotsFixture + }) + + expect( + screen.getAllByTestId(/summary\.json/).map(plot => plot.id) + ).toStrictEqual([ + 'custom-metrics:summary.json:loss-params:params.yaml:dropout', + 'custom-metrics:summary.json:accuracy-params:params.yaml:epochs' + ]) + + sendSetDataMessage({ + custom: { + ...customPlotsFixture, + plots: customPlotsFixture.plots.slice(1) + } + }) + + expect( + screen.getAllByTestId(/summary\.json/).map(plot => plot.id) + ).toStrictEqual([ + 'custom-metrics:summary.json:accuracy-params:params.yaml:epochs' + ]) + }) + it('should not change the metric order in the hover menu by reordering the plots', () => { renderAppWithOptionalData({ checkpoint: checkpointPlotsFixture @@ -1373,10 +1463,11 @@ describe('App', () => { renderAppWithOptionalData({ checkpoint: checkpointPlotsFixture, comparison: comparisonTableFixture, + custom: customPlotsFixture, template: complexTemplatePlotsFixture }) - const [templateInfo, comparisonInfo, checkpointInfo] = + const [templateInfo, comparisonInfo, checkpointInfo, customInfo] = screen.getAllByTestId('info-tooltip-toggle') fireEvent.mouseEnter(templateInfo, { bubbles: true }) @@ -1387,12 +1478,16 @@ describe('App', () => { fireEvent.mouseEnter(checkpointInfo, { bubbles: true }) expect(screen.getByTestId('tooltip-checkpoint-plots')).toBeInTheDocument() + + fireEvent.mouseEnter(customInfo, { bubbles: true }) + expect(screen.getByTestId('tooltip-custom-plots')).toBeInTheDocument() }) it('should dismiss a tooltip by pressing esc', () => { renderAppWithOptionalData({ checkpoint: checkpointPlotsFixture, comparison: comparisonTableFixture, + custom: customPlotsFixture, template: complexTemplatePlotsFixture }) diff --git a/webview/src/plots/components/App.tsx b/webview/src/plots/components/App.tsx index 182178c862..67e7ead963 100644 --- a/webview/src/plots/components/App.tsx +++ b/webview/src/plots/components/App.tsx @@ -2,6 +2,7 @@ import React, { useCallback } from 'react' import { useDispatch } from 'react-redux' import { CheckpointPlotsData, + CustomPlotsData, PlotsComparisonData, PlotsData, PlotsDataKeys, @@ -16,6 +17,10 @@ import { setCollapsed as setCheckpointPlotsCollapsed, update as updateCheckpointPlots } from './checkpointPlots/checkpointPlotsSlice' +import { + setCollapsed as setCustomPlotsCollapsed, + update as updateCustomPlots +} from './customPlots/customPlotsSlice' import { setCollapsed as setComparisonTableCollapsed, update as updateComparisonTable @@ -39,6 +44,7 @@ const dispatchCollapsedSections = ( ) => { if (sections) { dispatch(setCheckpointPlotsCollapsed(sections[Section.CHECKPOINT_PLOTS])) + dispatch(setCustomPlotsCollapsed(sections[Section.CUSTOM_PLOTS])) dispatch(setComparisonTableCollapsed(sections[Section.COMPARISON_TABLE])) dispatch(setTemplatePlotsCollapsed(sections[Section.TEMPLATE_PLOTS])) } @@ -56,6 +62,9 @@ export const feedStore = ( case PlotsDataKeys.CHECKPOINT: dispatch(updateCheckpointPlots(data.data[key] as CheckpointPlotsData)) continue + case PlotsDataKeys.CUSTOM: + dispatch(updateCustomPlots(data.data[key] as CustomPlotsData)) + continue case PlotsDataKeys.COMPARISON: dispatch(updateComparisonTable(data.data[key] as PlotsComparisonData)) continue diff --git a/webview/src/plots/components/Plots.tsx b/webview/src/plots/components/Plots.tsx index 8f990648df..0bc90dc245 100644 --- a/webview/src/plots/components/Plots.tsx +++ b/webview/src/plots/components/Plots.tsx @@ -3,6 +3,7 @@ import { useSelector, useDispatch } from 'react-redux' import { AddPlots, Welcome } from './GetStarted' import { ZoomedInPlot } from './ZoomedInPlot' import { CheckpointPlotsWrapper } from './checkpointPlots/CheckpointPlotsWrapper' +import { CustomPlotsWrapper } from './customPlots/CustomPlotsWrapper' import { TemplatePlotsWrapper } from './templatePlots/TemplatePlotsWrapper' import { ComparisonTableWrapper } from './comparisonTable/ComparisonTableWrapper' import { Ribbon } from './ribbon/Ribbon' @@ -32,6 +33,7 @@ const PlotsContent = () => { const hasTemplateData = useSelector( (state: PlotsState) => state.template.hasData ) + const hasCustomData = useSelector((state: PlotsState) => state.custom.hasData) const wrapperRef = createRef() useLayoutEffect(() => { @@ -52,7 +54,12 @@ const PlotsContent = () => { return Loading Plots... } - if (!hasCheckpointData && !hasComparisonData && !hasTemplateData) { + if ( + !hasCheckpointData && + !hasComparisonData && + !hasTemplateData && + !hasCustomData + ) { return ( { + {zoomedInPlot?.plot && ( void } + removePlotsButton?: { onClick: () => void } children: React.ReactNode } @@ -44,6 +48,13 @@ export const SectionDescription = { are enabled. ), + // "Custom" + [Section.CUSTOM_PLOTS]: ( + + Generated custom linear plots comparing chosen metrics and params in all + experiments in the table. + + ), // "Images" [Section.COMPARISON_TABLE]: ( @@ -83,7 +94,9 @@ export const PlotsContainer: React.FC = ({ title, children, currentSize, - menu + menu, + addPlotsButton, + removePlotsButton }) => { const open = !sectionCollapsed @@ -101,6 +114,22 @@ export const PlotsContainer: React.FC = ({ }) } + if (addPlotsButton) { + menuItems.unshift({ + icon: Add, + onClick: addPlotsButton.onClick, + tooltip: 'Add Plots' + }) + } + + if (removePlotsButton) { + menuItems.unshift({ + icon: Trash, + onClick: removePlotsButton.onClick, + tooltip: 'Remove Plots' + }) + } + const tooltipContent = (
diff --git a/webview/src/plots/components/checkpointPlots/checkpointPlotsSlice.ts b/webview/src/plots/components/checkpointPlots/checkpointPlotsSlice.ts index 88dba386e2..ea0a264a53 100644 --- a/webview/src/plots/components/checkpointPlots/checkpointPlotsSlice.ts +++ b/webview/src/plots/components/checkpointPlots/checkpointPlotsSlice.ts @@ -6,6 +6,7 @@ import { Section } from 'dvc/src/plots/webview/contract' import { addPlotsWithSnapshots, removePlots } from '../plotDataStore' + export interface CheckpointPlotsState extends Omit { isCollapsed: boolean diff --git a/webview/src/plots/components/customPlots/CustomPlot.tsx b/webview/src/plots/components/customPlots/CustomPlot.tsx new file mode 100644 index 0000000000..7dfe6badc3 --- /dev/null +++ b/webview/src/plots/components/customPlots/CustomPlot.tsx @@ -0,0 +1,51 @@ +import { Section } from 'dvc/src/plots/webview/contract' +import React, { useMemo, useEffect, useState } from 'react' +import { useSelector } from 'react-redux' +import { createSpec } from './util' +import { changeDisabledDragIds, changeSize } from './customPlotsSlice' +import { ZoomablePlot } from '../ZoomablePlot' +import styles from '../styles.module.scss' +import { withScale } from '../../../util/styles' +import { plotDataStore } from '../plotDataStore' +import { PlotsState } from '../../store' + +interface CustomPlotProps { + id: string +} + +export const CustomPlot: React.FC = ({ id }) => { + const plotSnapshot = useSelector( + (state: PlotsState) => state.custom.plotsSnapshots[id] + ) + const [plot, setPlot] = useState(plotDataStore[Section.CUSTOM_PLOTS][id]) + const currentSize = useSelector((state: PlotsState) => state.custom.size) + + const spec = useMemo(() => { + if (plot) { + return createSpec(plot.metric, plot.param) + } + }, [plot]) + + useEffect(() => { + setPlot(plotDataStore[Section.CUSTOM_PLOTS][id]) + }, [plotSnapshot, id]) + + if (!plot || !spec) { + return null + } + + const key = `plot-${id}` + + return ( +
+ +
+ ) +} diff --git a/webview/src/plots/components/customPlots/CustomPlots.tsx b/webview/src/plots/components/customPlots/CustomPlots.tsx new file mode 100644 index 0000000000..efc6d54064 --- /dev/null +++ b/webview/src/plots/components/customPlots/CustomPlots.tsx @@ -0,0 +1,101 @@ +import React, { DragEvent, useEffect, useState } from 'react' +import { useSelector } from 'react-redux' +import cx from 'classnames' +import { MessageFromWebviewType } from 'dvc/src/webview/contract' +import { performSimpleOrderedUpdate } from 'dvc/src/util/array' +import { CustomPlot } from './CustomPlot' +import styles from '../styles.module.scss' +import { EmptyState } from '../../../shared/components/emptyState/EmptyState' +import { + DragDropContainer, + WrapperProps +} from '../../../shared/components/dragDrop/DragDropContainer' +import { DropTarget } from '../DropTarget' +import { VirtualizedGrid } from '../../../shared/components/virtualizedGrid/VirtualizedGrid' +import { shouldUseVirtualizedGrid } from '../util' +import { PlotsState } from '../../store' +import { sendMessage } from '../../../shared/vscode' +import { changeOrderWithDraggedInfo } from '../../../util/array' + +interface CustomPlotsProps { + plotsIds: string[] +} + +export const CustomPlots: React.FC = ({ plotsIds }) => { + const [order, setOrder] = useState(plotsIds) + const { size, hasData, disabledDragPlotIds } = useSelector( + (state: PlotsState) => state.custom + ) + const [onSection, setOnSection] = useState(false) + const nbItemsPerRow = size + const draggedRef = useSelector( + (state: PlotsState) => state.dragAndDrop.draggedRef + ) + + useEffect(() => { + setOrder(pastOrder => performSimpleOrderedUpdate(pastOrder, plotsIds)) + }, [plotsIds]) + + const setPlotsIdsOrder = (order: string[]): void => { + setOrder(order) + sendMessage({ + payload: order, + type: MessageFromWebviewType.REORDER_PLOTS_CUSTOM + }) + } + + if (!hasData) { + return No Plots to Display + } + + const items = order.map(plot => ( +
+ +
+ )) + + const useVirtualizedGrid = shouldUseVirtualizedGrid(items.length, size) + + const handleDragOver = (e: DragEvent) => { + e.preventDefault() + setOnSection(true) + } + + const handleDropAtTheEnd = () => { + setPlotsIdsOrder(changeOrderWithDraggedInfo(order, draggedRef)) + } + + return items.length > 0 ? ( +
setOnSection(true)} + onDragLeave={() => setOnSection(false)} + onDragOver={handleDragOver} + onDrop={handleDropAtTheEnd} + > + } + wrapperComponent={ + useVirtualizedGrid + ? { + component: VirtualizedGrid as React.FC, + props: { nbItemsPerRow } + } + : undefined + } + parentDraggedOver={onSection} + /> +
+ ) : ( + No Plots Added + ) +} diff --git a/webview/src/plots/components/customPlots/CustomPlotsWrapper.tsx b/webview/src/plots/components/customPlots/CustomPlotsWrapper.tsx new file mode 100644 index 0000000000..584ba27f9c --- /dev/null +++ b/webview/src/plots/components/customPlots/CustomPlotsWrapper.tsx @@ -0,0 +1,40 @@ +import { Section } from 'dvc/src/plots/webview/contract' +import React, { useEffect, useState } from 'react' +import { useSelector } from 'react-redux' +import { MessageFromWebviewType } from 'dvc/src/webview/contract' +import { CustomPlots } from './CustomPlots' +import { PlotsContainer } from '../PlotsContainer' +import { PlotsState } from '../../store' +import { sendMessage } from '../../../shared/vscode' + +export const CustomPlotsWrapper: React.FC = () => { + const { plotsIds, size, isCollapsed } = useSelector( + (state: PlotsState) => state.custom + ) + const [selectedPlots, setSelectedPlots] = useState([]) + useEffect(() => { + setSelectedPlots(plotsIds) + }, [plotsIds, setSelectedPlots]) + const addCustomPlot = () => { + sendMessage({ type: MessageFromWebviewType.ADD_CUSTOM_PLOT }) + } + + const removeCustomPlots = () => { + sendMessage({ type: MessageFromWebviewType.REMOVE_CUSTOM_PLOTS }) + } + + return ( + 0 ? { onClick: removeCustomPlots } : undefined + } + > + + + ) +} diff --git a/webview/src/plots/components/customPlots/customPlotsSlice.ts b/webview/src/plots/components/customPlots/customPlotsSlice.ts new file mode 100644 index 0000000000..5b1900e528 --- /dev/null +++ b/webview/src/plots/components/customPlots/customPlotsSlice.ts @@ -0,0 +1,62 @@ +import { createSlice, PayloadAction } from '@reduxjs/toolkit' +import { + CustomPlotsData, + DEFAULT_SECTION_COLLAPSED, + DEFAULT_SECTION_SIZES, + Section +} from 'dvc/src/plots/webview/contract' +import { addPlotsWithSnapshots, removePlots } from '../plotDataStore' + +export interface CustomPlotsState extends Omit { + isCollapsed: boolean + hasData: boolean + plotsIds: string[] + plotsSnapshots: { [key: string]: string } + disabledDragPlotIds: string[] +} + +export const customPlotsInitialState: CustomPlotsState = { + disabledDragPlotIds: [], + hasData: false, + isCollapsed: DEFAULT_SECTION_COLLAPSED[Section.CUSTOM_PLOTS], + plotsIds: [], + plotsSnapshots: {}, + size: DEFAULT_SECTION_SIZES[Section.CUSTOM_PLOTS] +} + +export const customPlotsSlice = createSlice({ + initialState: customPlotsInitialState, + name: 'custom', + reducers: { + changeDisabledDragIds: (state, action: PayloadAction) => { + state.disabledDragPlotIds = action.payload + }, + changeSize: (state, action: PayloadAction) => { + state.size = action.payload + }, + setCollapsed: (state, action: PayloadAction) => { + state.isCollapsed = action.payload + }, + update: (state, action: PayloadAction) => { + if (!action.payload) { + return customPlotsInitialState + } + const { plots, ...statePayload } = action.payload + const plotsIds = plots?.map(plot => plot.id) || [] + const snapShots = addPlotsWithSnapshots(plots, Section.CUSTOM_PLOTS) + removePlots(plotsIds, Section.CUSTOM_PLOTS) + return { + ...state, + ...statePayload, + hasData: !!action.payload, + plotsIds: plots?.map(plot => plot.id) || [], + plotsSnapshots: snapShots + } + } + } +}) + +export const { update, setCollapsed, changeSize, changeDisabledDragIds } = + customPlotsSlice.actions + +export default customPlotsSlice.reducer diff --git a/webview/src/plots/components/customPlots/util.ts b/webview/src/plots/components/customPlots/util.ts new file mode 100644 index 0000000000..bbb7cfbd49 --- /dev/null +++ b/webview/src/plots/components/customPlots/util.ts @@ -0,0 +1,87 @@ +import { VisualizationSpec } from 'react-vega' + +export const createSpec = (metric: string, param: string) => + ({ + $schema: 'https://vega.github.io/schema/vega-lite/v5.json', + data: { name: 'values' }, + encoding: { + x: { + field: 'param', + title: param, + type: 'quantitative' + }, + y: { + field: 'metric', + scale: { zero: false }, + title: metric, + type: 'quantitative' + } + }, + height: 'container', + layer: [ + { + layer: [ + { + mark: { + type: 'line' + } + }, + { + mark: { + type: 'point' + }, + transform: [ + { + filter: { + param: 'hover' + } + } + ] + } + ] + }, + { + encoding: { + opacity: { + value: 0 + }, + tooltip: [ + { + field: 'expName', + title: 'name' + }, + { + field: 'metric', + title: metric + }, + { + field: 'param', + title: param + } + ] + }, + mark: { + type: 'rule' + }, + params: [ + { + name: 'hover', + select: { + clear: 'mouseout', + fields: ['param', 'metric'], + nearest: true, + on: 'mouseover', + type: 'point' + } + } + ] + } + ], + transform: [ + { + as: 'y', + calculate: "format(datum['y'],'.5f')" + } + ], + width: 'container' + } as VisualizationSpec) diff --git a/webview/src/plots/components/plotDataStore.ts b/webview/src/plots/components/plotDataStore.ts index 76e106ab0e..51436b9d96 100644 --- a/webview/src/plots/components/plotDataStore.ts +++ b/webview/src/plots/components/plotDataStore.ts @@ -1,20 +1,23 @@ import { CheckpointPlotData, + CustomPlotData, Section, TemplatePlotEntry } from 'dvc/src/plots/webview/contract' export type CheckpointPlotsById = { [key: string]: CheckpointPlotData } +export type CustomPlotsById = { [key: string]: CustomPlotData } export type TemplatePlotsById = { [key: string]: TemplatePlotEntry } export const plotDataStore = { [Section.CHECKPOINT_PLOTS]: {} as CheckpointPlotsById, [Section.TEMPLATE_PLOTS]: {} as TemplatePlotsById, - [Section.COMPARISON_TABLE]: {} as CheckpointPlotsById // This category is unused but exists only to make typings easier + [Section.COMPARISON_TABLE]: {} as CheckpointPlotsById, // This category is unused but exists only to make typings easier, + [Section.CUSTOM_PLOTS]: {} as CustomPlotsById } export const addPlotsWithSnapshots = ( - plots: (CheckpointPlotData | TemplatePlotEntry)[], + plots: (CheckpointPlotData | TemplatePlotEntry | CustomPlotData)[], section: Section ) => { const snapShots: { [key: string]: string } = {} diff --git a/webview/src/plots/hooks/useGetPlot.ts b/webview/src/plots/hooks/useGetPlot.ts index 74f7275cda..955390f0cf 100644 --- a/webview/src/plots/hooks/useGetPlot.ts +++ b/webview/src/plots/hooks/useGetPlot.ts @@ -1,5 +1,6 @@ import { CheckpointPlotData, + CustomPlotData, Section, TemplatePlotEntry } from 'dvc/src/plots/webview/contract' @@ -9,13 +10,25 @@ import { PlainObject, VisualizationSpec } from 'react-vega' import { plotDataStore } from '../components/plotDataStore' import { PlotsState } from '../store' +const getStoreSection = (section: Section) => { + switch (section) { + case Section.CHECKPOINT_PLOTS: + return 'checkpoint' + case Section.TEMPLATE_PLOTS: + return 'template' + default: + return 'custom' + } +} + export const useGetPlot = ( section: Section, id: string, spec?: VisualizationSpec ) => { - const isCheckpointPlot = section === Section.CHECKPOINT_PLOTS - const storeSection = isCheckpointPlot ? 'checkpoint' : 'template' + const isPlotWithSpec = + section === Section.CHECKPOINT_PLOTS || section === Section.CUSTOM_PLOTS + const storeSection = getStoreSection(section) const snapshot = useSelector( (state: PlotsState) => state[storeSection].plotsSnapshots ) @@ -28,8 +41,8 @@ export const useGetPlot = ( return } - if (isCheckpointPlot) { - setData({ values: (plot as CheckpointPlotData).values }) + if (isPlotWithSpec) { + setData({ values: (plot as CheckpointPlotData | CustomPlotData).values }) setContent(spec) return } @@ -40,7 +53,7 @@ export const useGetPlot = ( height: 'container', width: 'container' } as VisualizationSpec) - }, [id, isCheckpointPlot, setData, setContent, section, spec]) + }, [id, isPlotWithSpec, setData, setContent, section, spec]) useEffect(() => { setPlotData() diff --git a/webview/src/plots/store.ts b/webview/src/plots/store.ts index 57f1bb6cda..9686b1fda8 100644 --- a/webview/src/plots/store.ts +++ b/webview/src/plots/store.ts @@ -2,6 +2,7 @@ import { configureStore } from '@reduxjs/toolkit' import checkpointPlotsReducer from './components/checkpointPlots/checkpointPlotsSlice' import comparisonTableReducer from './components/comparisonTable/comparisonTableSlice' import templatePlotsReducer from './components/templatePlots/templatePlotsSlice' +import customPlotsReducer from './components/customPlots/customPlotsSlice' import webviewReducer from './components/webviewSlice' import ribbonReducer from './components/ribbon/ribbonSlice' import dragAndDropReducer from '../shared/components/dragDrop/dragDropSlice' @@ -9,6 +10,7 @@ import dragAndDropReducer from '../shared/components/dragDrop/dragDropSlice' export const plotsReducers = { checkpoint: checkpointPlotsReducer, comparison: comparisonTableReducer, + custom: customPlotsReducer, dragAndDrop: dragAndDropReducer, ribbon: ribbonReducer, template: templatePlotsReducer, diff --git a/webview/src/shared/components/icons/Trash.tsx b/webview/src/shared/components/icons/Trash.tsx new file mode 100644 index 0000000000..299c5a4366 --- /dev/null +++ b/webview/src/shared/components/icons/Trash.tsx @@ -0,0 +1,23 @@ +import * as React from 'react' + +function SvgTrash(props: React.SVGProps) { + return ( + + + + ) +} + +export default SvgTrash diff --git a/webview/src/shared/components/icons/index.ts b/webview/src/shared/components/icons/index.ts index 32032c343d..db0a6b4e15 100644 --- a/webview/src/shared/components/icons/index.ts +++ b/webview/src/shared/components/icons/index.ts @@ -22,3 +22,4 @@ export { default as UpArrow } from './UpArrow' export { default as SortPrecedence } from './SortPrecedence' export { default as StarFull } from './StarFull' export { default as StarEmpty } from './StarEmpty' +export { default as Trash } from './Trash' diff --git a/webview/src/stories/Plots.stories.tsx b/webview/src/stories/Plots.stories.tsx index d8474f7c28..3efbfc6742 100644 --- a/webview/src/stories/Plots.stories.tsx +++ b/webview/src/stories/Plots.stories.tsx @@ -12,6 +12,7 @@ import { } from 'dvc/src/plots/webview/contract' import { MessageToWebviewType } from 'dvc/src/webview/contract' import checkpointPlotsFixture from 'dvc/src/test/fixtures/expShow/base/checkpointPlots' +import customPlotsFixture from 'dvc/src/test/fixtures/expShow/base/customPlots' import templatePlotsFixture from 'dvc/src/test/fixtures/plotsDiff/template' import manyTemplatePlots from 'dvc/src/test/fixtures/plotsDiff/template/virtualization' import comparisonPlotsFixture from 'dvc/src/test/fixtures/plotsDiff/comparison' @@ -66,6 +67,7 @@ export default { data: { checkpoint: checkpointPlotsFixture, comparison: comparisonPlotsFixture, + custom: customPlotsFixture, hasPlots: true, hasUnselectedPlots: false, sectionCollapsed: DEFAULT_SECTION_COLLAPSED, @@ -115,6 +117,16 @@ WithCheckpointOnly.args = { } WithCheckpointOnly.parameters = DISABLE_CHROMATIC_SNAPSHOTS +export const WithCustomOnly = Template.bind({}) +WithCustomOnly.args = { + data: { + custom: customPlotsFixture, + sectionCollapsed: DEFAULT_SECTION_COLLAPSED, + selectedRevisions: plotsRevisionsFixture + } +} +WithCustomOnly.parameters = DISABLE_CHROMATIC_SNAPSHOTS + export const WithTemplateOnly = Template.bind({}) WithTemplateOnly.args = { data: { @@ -174,6 +186,7 @@ AllLarge.args = { data: { checkpoint: { ...checkpointPlotsFixture, size: PlotSizeNumber.LARGE }, comparison: { ...comparisonPlotsFixture, size: PlotSizeNumber.LARGE }, + custom: { ...customPlotsFixture, size: PlotSizeNumber.LARGE }, sectionCollapsed: DEFAULT_SECTION_COLLAPSED, selectedRevisions: plotsRevisionsFixture, template: { ...templatePlotsFixture, size: PlotSizeNumber.LARGE } @@ -186,6 +199,7 @@ AllSmall.args = { data: { checkpoint: smallCheckpointPlotsFixture, comparison: { ...comparisonPlotsFixture, size: PlotSizeNumber.SMALL }, + custom: { ...customPlotsFixture, size: PlotSizeNumber.SMALL }, sectionCollapsed: DEFAULT_SECTION_COLLAPSED, selectedRevisions: plotsRevisionsFixture, template: { ...templatePlotsFixture, size: PlotSizeNumber.SMALL } @@ -202,6 +216,7 @@ VirtualizedPlots.args = { selectedMetrics: manyCheckpointPlotsFixture.map(plot => plot.id) }, comparison: undefined, + custom: customPlotsFixture, sectionCollapsed: DEFAULT_SECTION_COLLAPSED, selectedRevisions: plotsRevisionsFixture, template: manyTemplatePlots(125) @@ -269,6 +284,7 @@ ScrolledWithManyRevisions.args = { data: { checkpoint: checkpointPlotsFixture, comparison: comparisonPlotsFixture, + custom: customPlotsFixture, hasPlots: true, hasUnselectedPlots: false, sectionCollapsed: DEFAULT_SECTION_COLLAPSED,