Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use linear path to update plots data when experiments update #2831

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion extension/src/cli/dvc/reader.ts
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ export class DvcReader extends DvcCli {
} catch (error: unknown) {
const msg =
(error as MaybeConsoleError).stderr || (error as Error).message
Logger.error(`${args} failed with ${msg} retrying...`)
Logger.error(`${args} failed with ${msg}`)
return { error: { msg, type: 'Caught error' } }
}
}
Expand Down
33 changes: 21 additions & 12 deletions extension/src/plots/data/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import { EventEmitter } from 'vscode'
import { PlotsOutput } from '../../cli/dvc/contract'
import { AvailableCommands, InternalCommands } from '../../commands/internal'
import { BaseData } from '../../data'
import { Experiments } from '../../experiments'
import {
definedAndNonEmpty,
flattenUnique,
Expand All @@ -10,12 +11,14 @@ import {
import { PlotsModel } from '../model'

export class PlotsData extends BaseData<{ data: PlotsOutput; revs: string[] }> {
private readonly model: PlotsModel
private readonly plots: PlotsModel
private readonly experiments: Experiments

constructor(
dvcRoot: string,
internalCommands: InternalCommands,
model: PlotsModel,
plots: PlotsModel,
experiments: Experiments,
updatesPaused: EventEmitter<boolean>
) {
super(
Expand All @@ -30,13 +33,15 @@ export class PlotsData extends BaseData<{ data: PlotsOutput; revs: string[] }> {
],
['dvc.yaml', 'dvc.lock']
)
this.model = model

this.plots = plots
this.experiments = experiments
}

public async update(): Promise<void> {
const revs = flattenUnique([
this.model.getMissingRevisions(),
this.model.getMutableRevisions()
this.plots.getMissingRevisions(this.experiments.getSelectedRevisions()),
this.experiments.getMutableRevisions()
])

if (
Expand All @@ -49,17 +54,13 @@ export class PlotsData extends BaseData<{ data: PlotsOutput; revs: string[] }> {
}

const args = this.getArgs(revs)
const data = await this.internalCommands.executeCommand<PlotsOutput>(
AvailableCommands.PLOTS_DIFF,
this.dvcRoot,
...args
)
const data = await this.fetch(args)

const files = this.collectFiles({ data })

this.compareFiles(files)

return this.notifyChanged({ data, revs })
this.notifyChanged({ data, revs })
}

public managedUpdate() {
Expand All @@ -70,9 +71,17 @@ export class PlotsData extends BaseData<{ data: PlotsOutput; revs: string[] }> {
return Object.keys(data)
}

public fetch(revs: string[]) {
return this.internalCommands.executeCommand<PlotsOutput>(
AvailableCommands.PLOTS_DIFF,
this.dvcRoot,
...revs
)
}

private getArgs(revs: string[]) {
const cliWillThrowError = sameContents(revs, ['workspace'])
if (this.model && cliWillThrowError) {
if (this.plots && cliWillThrowError) {
return []
}

Expand Down
81 changes: 63 additions & 18 deletions extension/src/plots/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,13 @@ import { BaseRepository } from '../webview/repository'
import { Experiments } from '../experiments'
import { Resource } from '../resourceLocator'
import { InternalCommands } from '../commands/internal'
import { definedAndNonEmpty } from '../util/array'
import { definedAndNonEmpty, sameContents } from '../util/array'
import { ExperimentsOutput } from '../cli/dvc/contract'
import { TEMP_PLOTS_DIR } from '../cli/dvc/constants'
import { removeDir } from '../fileSystem'
import { Toast } from '../vscode/toast'
import { pickPaths } from '../path/selection/quickPick'
import { SelectedExperimentWithColor } from '../experiments/model'

export type PlotsWebview = BaseWebview<TPlotsData>

Expand Down Expand Up @@ -48,7 +49,7 @@ export class Plots extends BaseRepository<TPlotsData> {
this.experiments = experiments

this.plots = this.dispose.track(
new PlotsModel(this.dvcRoot, experiments, workspaceState)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[F] Removing experiments from here so that we can't get into a race condition where the selected revisions change mid-update.

new PlotsModel(this.dvcRoot, workspaceState)
)
this.paths = this.dispose.track(
new PathsModel(this.dvcRoot, workspaceState)
Expand All @@ -61,7 +62,13 @@ export class Plots extends BaseRepository<TPlotsData> {
)

this.data = this.dispose.track(
new PlotsData(dvcRoot, internalCommands, this.plots, updatesPaused)
new PlotsData(
dvcRoot,
internalCommands,
this.plots,
this.experiments,
updatesPaused
)
)

this.onDidUpdateData()
Expand All @@ -77,7 +84,7 @@ export class Plots extends BaseRepository<TPlotsData> {
}

public sendInitialWebviewData() {
return this.fetchMissingOrSendPlots()
return this.fetchMissingAndSendPlots()
}

public togglePathStatus(path: string) {
Expand All @@ -104,7 +111,9 @@ export class Plots extends BaseRepository<TPlotsData> {
Toast.infoWithOptions(
'Attempting to refresh plots for selected experiments.'
)
for (const { revision } of this.plots.getSelectedRevisionDetails()) {
for (const { revision } of this.plots.getSelectedRevisionDetails(
this.experiments.getSelectedRevisions()
)) {
this.plots.setupManualRefresh(revision)
}
this.data.managedUpdate()
Expand All @@ -130,21 +139,42 @@ export class Plots extends BaseRepository<TPlotsData> {

private notifyChanged() {
this.pathsChanged.fire()
this.fetchMissingOrSendPlots()
this.fetchMissingAndSendPlots()
}

private async fetchMissingOrSendPlots() {
private async fetchMissingAndSendPlots(
overrideRevs?: SelectedExperimentWithColor[]
) {
await this.isReady()

const selectedRevs = overrideRevs || this.experiments.getSelectedRevisions()
const selectedExperiments = this.experiments.getSelectedExperiments()

const missingRevs = this.plots.getMissingRevisions(selectedRevs)

if (this.paths.hasPaths() && definedAndNonEmpty(missingRevs)) {
const data = await this.data.fetch(missingRevs)

await Promise.all([
this.plots.transformAndSetPlots(data, missingRevs, selectedRevs),
this.paths.transformAndSet(data)
])
}

this.webviewMessages.sendWebviewMessage(selectedRevs, selectedExperiments)
return this.syncWebview(selectedRevs)
}

private syncWebview(selectedRevs: SelectedExperimentWithColor[]) {
if (
this.paths.hasPaths() &&
definedAndNonEmpty(this.plots.getUnfetchedRevisions())
sameContents(
this.experiments.getSelectedRevisions().map(({ label }) => label),
selectedRevs.map(({ label }) => label)
)
) {
this.webviewMessages.sendCheckpointPlotsMessage()
return this.data.managedUpdate()
return
}

return this.webviewMessages.sendWebviewMessage()
this.fetchMissingAndSendPlots()
}

private createWebviewMessageHandler(
Expand Down Expand Up @@ -184,19 +214,30 @@ export class Plots extends BaseRepository<TPlotsData> {
private setupExperimentsListener(experiments: Experiments) {
this.dispose.track(
experiments.onDidChangeExperiments(async data => {
const selectedRevs = experiments.getSelectedRevisions()
if (data) {
await this.plots.transformAndSetExperiments(data)
await this.plots.transformAndSetExperiments(
data,
experiments.getBranchRevisions(),
experiments.getRevisions(),
experiments.hasCheckpoints()
)
}

this.plots.setComparisonOrder()
this.plots.setComparisonOrder(selectedRevs)

this.fetchMissingOrSendPlots()
this.fetchMissingAndSendPlots(selectedRevs)
})
)
}

private async initializeData(data: ExperimentsOutput) {
await this.plots.transformAndSetExperiments(data)
await this.plots.transformAndSetExperiments(
data,
this.experiments.getBranchRevisions(),
this.experiments.getRevisions(),
this.experiments.hasCheckpoints()
)
this.data.managedUpdate()
await Promise.all([
this.data.isReady(),
Expand All @@ -210,7 +251,11 @@ export class Plots extends BaseRepository<TPlotsData> {
this.dispose.track(
this.data.onDidUpdate(async ({ data, revs }) => {
await Promise.all([
this.plots.transformAndSetPlots(data, revs),
this.plots.transformAndSetPlots(
data,
revs,
this.experiments.getSelectedRevisions()
),
this.paths.transformAndSet(data)
])
this.notifyChanged()
Expand Down
56 changes: 23 additions & 33 deletions extension/src/plots/model/index.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,16 @@ import {
Section
} from '../webview/contract'
import { buildMockMemento } from '../../test/util'
import { Experiments } from '../../experiments'
import { PersistenceKey } from '../../persistence/constants'
import { SelectedExperimentWithColor } from '../../experiments/model'

const mockedRevisions = [
{ displayColor: 'white', label: 'workspace' },
{ displayColor: 'red', label: 'main' },
{ displayColor: 'blue', label: '71f31cf' },
{ displayColor: 'black', label: 'e93c7e6' },
{ displayColor: 'brown', label: 'ffbe811' }
]
] as unknown as SelectedExperimentWithColor[]

describe('plotsModel', () => {
let model: PlotsModel
Expand All @@ -26,17 +26,9 @@ describe('plotsModel', () => {
persistedSelectedMetrics,
[PersistenceKey.PLOT_SIZES + exampleDvcRoot]: DEFAULT_SECTION_SIZES
})
const mockedGetSelectedRevisions = jest.fn()

beforeEach(() => {
model = new PlotsModel(
exampleDvcRoot,
{
getSelectedRevisions: mockedGetSelectedRevisions,
isReady: () => Promise.resolve(undefined)
} as unknown as Experiments,
memento
)
model = new PlotsModel(exampleDvcRoot, memento)
jest.clearAllMocks()
})

Expand Down Expand Up @@ -112,11 +104,9 @@ describe('plotsModel', () => {
})

it('should reorder comparison revisions after receiving a message to reorder', () => {
mockedGetSelectedRevisions.mockReturnValue(mockedRevisions)

const mementoUpdateSpy = jest.spyOn(memento, 'update')
const newOrder = ['71f31cf', 'e93c7e6', 'ffbe811', 'workspace', 'main']
model.setComparisonOrder(newOrder)
model.setComparisonOrder(mockedRevisions, newOrder)

expect(mementoUpdateSpy).toHaveBeenCalledTimes(1)
expect(mementoUpdateSpy).toHaveBeenCalledWith(
Expand All @@ -125,19 +115,21 @@ describe('plotsModel', () => {
)

expect(
model.getSelectedRevisionDetails().map(({ revision }) => revision)
model
.getSelectedRevisionDetails(mockedRevisions)
.map(({ revision }) => revision)
).toStrictEqual(newOrder)
})

it('should always send new revisions to the end of the list', () => {
mockedGetSelectedRevisions.mockReturnValue(mockedRevisions)

const newOrder = ['71f31cf', 'e93c7e6']

model.setComparisonOrder(newOrder)
model.setComparisonOrder(mockedRevisions, newOrder)

expect(
model.getSelectedRevisionDetails().map(({ revision }) => revision)
model
.getSelectedRevisionDetails(mockedRevisions)
.map(({ revision }) => revision)
).toStrictEqual([
...newOrder,
...mockedRevisions
Expand All @@ -151,31 +143,29 @@ describe('plotsModel', () => {
const revisionDropped = allRevisions.filter(({ label }) => label !== 'main')
const revisionReAdded = allRevisions

mockedGetSelectedRevisions
.mockReturnValueOnce(allRevisions)
.mockReturnValueOnce(allRevisions)
.mockReturnValueOnce(revisionDropped)
.mockReturnValueOnce(revisionDropped)
.mockReturnValueOnce(revisionReAdded)
.mockReturnValueOnce(revisionReAdded)

const initialOrder = ['workspace', 'main', '71f31cf']
model.setComparisonOrder(initialOrder)
model.setComparisonOrder(allRevisions, initialOrder)

expect(
model.getSelectedRevisionDetails().map(({ revision }) => revision)
model
.getSelectedRevisionDetails(allRevisions)
.map(({ revision }) => revision)
).toStrictEqual(initialOrder)

model.setComparisonOrder()
model.setComparisonOrder(revisionDropped)

expect(
model.getSelectedRevisionDetails().map(({ revision }) => revision)
model
.getSelectedRevisionDetails(revisionDropped)
.map(({ revision }) => revision)
).toStrictEqual(initialOrder.filter(revision => revision !== 'main'))

model.setComparisonOrder()
model.setComparisonOrder(revisionReAdded)

expect(
model.getSelectedRevisionDetails().map(({ revision }) => revision)
model
.getSelectedRevisionDetails(revisionReAdded)
.map(({ revision }) => revision)
).toStrictEqual(['workspace', '71f31cf', 'main'])
})
})
Loading