From f28d2d16935927aea1db89ea5a97ec952be5ebba Mon Sep 17 00:00:00 2001 From: Matt Seddon Date: Wed, 29 Mar 2023 13:13:34 +1100 Subject: [PATCH] Ensure that data for broken revisions is dropped --- extension/src/plots/model/index.ts | 12 +++- extension/src/plots/paths/collect.test.ts | 39 ++++++++--- extension/src/plots/paths/collect.ts | 28 ++++---- extension/src/test/suite/plots/index.test.ts | 74 +++++++++++++++++++- 4 files changed, 126 insertions(+), 27 deletions(-) diff --git a/extension/src/plots/model/index.ts b/extension/src/plots/model/index.ts index 636700422c..45e09f0bea 100644 --- a/extension/src/plots/model/index.ts +++ b/extension/src/plots/model/index.ts @@ -124,7 +124,7 @@ export class PlotsModel extends ModelWithPersistence { if (isDvcError(output)) { this.handleCliError() } else { - await this.processOutput(output, cliIdToLabel) + await this.processOutput(output, revs, cliIdToLabel) } this.setComparisonOrder() @@ -403,8 +403,13 @@ export class PlotsModel extends ModelWithPersistence { private async processOutput( output: PlotsOutput, + revs: string[], cliIdToLabel: CLIRevisionIdToLabel ) { + for (const rev of revs) { + this.deleteRevisionData(cliIdToLabel[rev] || rev) + } + const [{ comparisonData, revisionData }, templates, multiSourceVariations] = await Promise.all([ collectData(output, cliIdToLabel), @@ -420,7 +425,7 @@ export class PlotsModel extends ModelWithPersistence { ...this.revisionData, ...revisionData } - this.templates = { ...this.templates, ...templates } + this.templates = templates this.multiSourceVariations = multiSourceVariations this.multiSourceEncoding = collectMultiSourceEncoding( this.multiSourceVariations @@ -468,10 +473,12 @@ export class PlotsModel extends ModelWithPersistence { for (const id of Object.keys(this.commitRevisions)) { if (this.commitRevisions[id] !== currentCommitRevisions[id]) { this.deleteRevisionData(id) + this.fetchedRevs.delete(id) } } if (!isEqual(this.commitRevisions, currentCommitRevisions)) { this.deleteRevisionData(EXPERIMENT_WORKSPACE_ID) + this.fetchedRevs.delete(EXPERIMENT_WORKSPACE_ID) } this.commitRevisions = currentCommitRevisions } @@ -479,7 +486,6 @@ export class PlotsModel extends ModelWithPersistence { private deleteRevisionData(id: string) { delete this.revisionData[id] delete this.comparisonData[id] - this.fetchedRevs.delete(id) } private getCLIId(label: string) { diff --git a/extension/src/plots/paths/collect.test.ts b/extension/src/plots/paths/collect.test.ts index ac5812e58c..a42a44c819 100644 --- a/extension/src/plots/paths/collect.test.ts +++ b/extension/src/plots/paths/collect.test.ts @@ -12,6 +12,7 @@ import { TemplatePlotGroup, PlotsType } from '../webview/contract' import plotsDiffFixture from '../../test/fixtures/plotsDiff/output' import { Shape, StrokeDash } from '../multiSource/constants' import { EXPERIMENT_WORKSPACE_ID } from '../../cli/dvc/contract' +import { CLIRevisionIdToLabel } from '../model/collect' describe('collectPaths', () => { const revisions = [ @@ -80,13 +81,21 @@ describe('collectPaths', () => { ]) }) - it('should update the revision details when the workspace is recollected (plots in workspace changed)', () => { + it('should update the revision details when any revision is recollected', () => { const [remainingPath] = Object.keys(plotsDiffFixture) const collectedPaths = collectPaths([], plotsDiffFixture, revisions, {}) expect( collectedPaths.filter(path => path.revisions.has(EXPERIMENT_WORKSPACE_ID)) ).toHaveLength(collectedPaths.length) + const fetchedRevs = revisions.slice(0, 3) + const cliIdToLabel: CLIRevisionIdToLabel = {} + for (const rev of fetchedRevs) { + cliIdToLabel[rev] = rev + } + + cliIdToLabel[fetchedRevs[2]] = 'some-branch' + const updatedPaths = collectPaths( collectedPaths, { @@ -95,26 +104,40 @@ describe('collectPaths', () => { { content: {}, datapoints: { - [EXPERIMENT_WORKSPACE_ID]: [ + [fetchedRevs[0]]: [ + { + loss: '2.43323', + step: '0' + } + ], + [fetchedRevs[1]]: [ + { + loss: '2.43323', + step: '0' + } + ], + [fetchedRevs[2]]: [ { loss: '2.43323', step: '0' } ] }, - revisions: [EXPERIMENT_WORKSPACE_ID], + revisions: fetchedRevs, type: PlotsType.VEGA } ] } }, - [EXPERIMENT_WORKSPACE_ID], - {} + fetchedRevs, + cliIdToLabel ) - expect( - updatedPaths.filter(path => path.revisions.has(EXPERIMENT_WORKSPACE_ID)) - ).toHaveLength(remainingPath.split(sep).length) + for (const rev of Object.values(cliIdToLabel)) { + expect(updatedPaths.filter(path => path.revisions.has(rev))).toHaveLength( + remainingPath.split(sep).length + ) + } }) it('should not drop already collected paths', () => { diff --git a/extension/src/plots/paths/collect.ts b/extension/src/plots/paths/collect.ts index a64b8d46f5..3bb1e868bd 100644 --- a/extension/src/plots/paths/collect.ts +++ b/extension/src/plots/paths/collect.ts @@ -5,12 +5,7 @@ import { TemplatePlot, TemplatePlotGroup } from '../webview/contract' -import { - EXPERIMENT_WORKSPACE_ID, - PlotError, - PlotsData, - PlotsOutput -} from '../../cli/dvc/contract' +import { PlotError, PlotsData, PlotsOutput } from '../../cli/dvc/contract' import { getParent, getPath, getPathArray } from '../../fileSystem/util' import { splitMatchedOrdered, definedAndNonEmpty } from '../../util/array' import { isMultiViewPlot } from '../vega/util' @@ -22,6 +17,7 @@ import { StrokeDashValue } from '../multiSource/constants' import { MultiSourceEncoding } from '../multiSource/collect' +import { CLIRevisionIdToLabel } from '../model/collect' export enum PathType { COMPARISON = 'comparison', @@ -71,17 +67,17 @@ const getType = ( return collectType(plots) } -const filterWorkspaceIfFetched = ( +const filterRevisionIfFetched = ( existingPaths: PlotPath[], - fetchedRevs: string[] + fetchedRevs: string[], + cliIdToLabel: CLIRevisionIdToLabel ) => { - if (!fetchedRevs.includes(EXPERIMENT_WORKSPACE_ID)) { - return existingPaths - } - return existingPaths.map(existing => { const revisions = existing.revisions - revisions.delete(EXPERIMENT_WORKSPACE_ID) + for (const rev of fetchedRevs) { + const id = cliIdToLabel[rev] || rev + revisions.delete(id) + } return { ...existing, revisions } }) } @@ -233,7 +229,11 @@ export const collectPaths = ( fetchedRevs: string[], cliIdToLabel: { [id: string]: string } ): PlotPath[] => { - let acc: PlotPath[] = filterWorkspaceIfFetched(existingPaths, fetchedRevs) + let acc: PlotPath[] = filterRevisionIfFetched( + existingPaths, + fetchedRevs, + cliIdToLabel + ) const { data, errors } = output diff --git a/extension/src/test/suite/plots/index.test.ts b/extension/src/test/suite/plots/index.test.ts index cd82afb0e1..16af64ece5 100644 --- a/extension/src/test/suite/plots/index.test.ts +++ b/extension/src/test/suite/plots/index.test.ts @@ -5,6 +5,7 @@ import { afterEach, beforeEach, describe, it, suite } from 'mocha' import { expect } from 'chai' import { restore, spy, stub } from 'sinon' import { commands, Uri } from 'vscode' +import isEqual from 'lodash.isequal' import { buildPlots } from '../plots/util' import { Disposable } from '../../../extension' import expShowFixtureWithoutErrors from '../../fixtures/expShow/base/noErrors' @@ -31,12 +32,14 @@ import { PlotsSection, TemplatePlotGroup, TemplatePlotsData, - CustomPlotType + CustomPlotType, + TemplatePlot, + ImagePlot } from '../../../plots/webview/contract' import { TEMP_PLOTS_DIR } from '../../../cli/dvc/constants' import { WEBVIEW_TEST_TIMEOUT } from '../timeouts' import { MessageFromWebviewType } from '../../../webview/contract' -import { reorderObjectList } from '../../../util/array' +import { reorderObjectList, uniqueValues } from '../../../util/array' import * as Telemetry from '../../../telemetry' import { EventName } from '../../../telemetry/constants' import { @@ -686,6 +689,73 @@ suite('Plots Test Suite', () => { expect(webview.isVisible()).to.be.true }).timeout(WEBVIEW_TEST_TIMEOUT) + it("should remove a revision's data if the revision is re-fetched and now contains an error", async () => { + const accPngPath = join('plots', 'acc.png') + const accPng = [ + ...plotsDiffFixture.data[join('plots', 'acc.png')] + ] as ImagePlot[] + const lossTsvPath = join('logs', 'loss.tsv') + const lossTsv = [...plotsDiffFixture.data[lossTsvPath]] as TemplatePlot[] + + const plotsDiffOutput = { + data: { + [accPngPath]: accPng, + [lossTsvPath]: lossTsv + } + } + + const brokenRev = '4fb124a' + + const reFetchedOutput = { + data: { + [accPngPath]: accPng.filter( + ({ revisions }) => !isEqual(revisions, [brokenRev]) + ), + [lossTsvPath]: lossTsv.map((plot, i) => { + const datapoints = { ...lossTsv[i].datapoints } + delete datapoints[brokenRev] + + return { + ...plot, + datapoints, + revisions: lossTsv[i].revisions?.filter(rev => rev !== brokenRev) + } + }) + } + } + + const { mockPlotsDiff, plots, data, plotsModel } = await buildPlots( + disposable, + plotsDiffOutput + ) + + await plots.isReady() + + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const getExistingRevisions = (plotsModel: any) => + uniqueValues([ + ...Object.keys(plotsModel.revisionData), + ...Object.keys(plotsModel.comparisonData) + ]) + + expect(getExistingRevisions(plotsModel)).to.contain(brokenRev) + + mockPlotsDiff.resetBehavior() + mockPlotsDiff.resolves(reFetchedOutput) + + const dataUpdated = new Promise(resolve => + data.onDidUpdate(() => resolve(undefined)) + ) + + await data.update() + await dataUpdated + + expect( + getExistingRevisions(plotsModel), + 'the revision should not exist in the underlying data' + ).not.to.contain(brokenRev) + }) + it('should send the correct data to the webview for flexible plots', async () => { const { experiments, plots, messageSpy, mockPlotsDiff } = await buildPlots(disposable, multiSourcePlotsDiffFixture)