Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add "Image by Step" plots #4372

Merged
merged 25 commits into from
Aug 4, 2023
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions extension/src/cli/dvc/constants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ export const EXP_RWLOCK_FILE = join(TEMP_EXP_DIR, 'rwlock.lock')
export const DEFAULT_NUM_OF_COMMITS_TO_SHOW = 3
export const NUM_OF_COMMITS_TO_INCREASE = 2

export const MULTI_IMAGE_PATH_REG = /\w+[/\\|]\d+\.[a-z]+$/i

export enum Command {
ADD = 'add',
CHECKOUT = 'checkout',
Expand Down
35 changes: 32 additions & 3 deletions extension/src/plots/model/collect.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import {
collectImageUrl
} from './collect'
import plotsDiffFixture from '../../test/fixtures/plotsDiff/output'
import multiImagePlotsDiffFixture from '../../test/fixtures/plotsDiff/output/multiImage'
import customPlotsFixture, {
customPlotsOrderFixture,
experimentsWithCommits
Expand Down Expand Up @@ -91,9 +92,37 @@ describe('collectData', () => {

expect(testBranchHeatmap).toBeDefined()
expect(testBranchHeatmap).toStrictEqual([
plotsDiffFixture.data[heatmapPlot].find(({ revisions }) =>
sameContents(revisions as string[], ['test-branch'])
)
{
...plotsDiffFixture.data[heatmapPlot].find(({ revisions }) =>
sameContents(revisions as string[], ['test-branch'])
),
path: heatmapPlot
}
])
})

it('should return the expected output from the comparison multi image test fixture', () => {
const { comparisonData } = collectData(multiImagePlotsDiffFixture)

const heatmapPlot = join('plots', 'heatmap.png')

expect(Object.keys(comparisonData.main)).toStrictEqual([
join('plots', 'acc.png'),
heatmapPlot,
join('plots', 'loss.png'),
join('plots', 'image')
])

const testBranchHeatmap = comparisonData['test-branch'][heatmapPlot]

expect(testBranchHeatmap).toBeDefined()
expect(testBranchHeatmap).toStrictEqual([
{
...plotsDiffFixture.data[heatmapPlot].find(({ revisions }) =>
sameContents(revisions as string[], ['test-branch'])
),
path: heatmapPlot
}
])
})
})
Expand Down
24 changes: 20 additions & 4 deletions extension/src/plots/model/collect.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ import {
import { StrokeDashEncoding } from '../multiSource/constants'
import { exists } from '../../fileSystem'
import { hasKey } from '../../util/object'
import { MULTI_IMAGE_PATH_REG } from '../../cli/dvc/constants'
import { getParent, getPathArray } from '../../fileSystem/util'

export const getCustomPlotId = (metric: string, param: string) =>
`custom-${metric}-${param}`
Expand Down Expand Up @@ -130,7 +132,7 @@ export type RevisionData = {

export type ComparisonData = {
[label: string]: {
[path: string]: ImagePlot[]
[path: string]: (ImagePlot & { path: string })[]
}
}

Expand All @@ -139,7 +141,9 @@ const collectImageData = (
path: string,
plot: ImagePlot
) => {
const pathLabel = path
const pathLabel = MULTI_IMAGE_PATH_REG.test(path)
? (getParent(getPathArray(path), 0) as string)
: path
const id = plot.revisions?.[0]
if (!id) {
return
Expand All @@ -153,7 +157,7 @@ const collectImageData = (
acc[id][pathLabel] = []
}

acc[id][pathLabel].push(plot)
acc[id][pathLabel].push({ ...plot, path })
}

const collectDatapoints = (
Expand Down Expand Up @@ -209,6 +213,16 @@ const collectPathData = (acc: DataAccumulator, path: string, plots: Plot[]) => {
}
}

const sortComparisonImgPaths = (acc: DataAccumulator) => {
for (const [label, paths] of Object.entries(acc.comparisonData)) {
for (const path of Object.keys(paths)) {
acc.comparisonData[label][path].sort((img1, img2) =>
img1.path.localeCompare(img2.path, undefined, { numeric: true })
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

dvc plots diff doesn't give you directly images in the step order so we need to sort them on our end.

)
}
}
}

export const collectData = (output: PlotsOutput): DataAccumulator => {
const { data } = output
const acc = {
Expand All @@ -220,6 +234,8 @@ export const collectData = (output: PlotsOutput): DataAccumulator => {
collectPathData(acc, path, plots)
}

sortComparisonImgPaths(acc)

return acc
}

Expand Down Expand Up @@ -248,7 +264,7 @@ const collectSelectedPathComparisonPlots = ({
}

for (const id of selectedRevisionIds) {
const imgs = comparisonData[id][path]
const imgs = comparisonData[id]?.[path]
pathRevisions.revisions[id] = {
id,
imgs: imgs
Expand Down
42 changes: 42 additions & 0 deletions extension/src/plots/paths/collect.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import {
} from './collect'
import { TemplatePlotGroup, PlotsType } from '../webview/contract'
import plotsDiffFixture from '../../test/fixtures/plotsDiff/output'
import multiImagePlotsDiffFixture from '../../test/fixtures/plotsDiff/output/multiImage'
import { Shape, StrokeDash } from '../multiSource/constants'
import { EXPERIMENT_WORKSPACE_ID } from '../../cli/dvc/contract'
import { REVISIONS } from '../../test/fixtures/plotsDiff'
Expand Down Expand Up @@ -75,6 +76,47 @@ describe('collectPaths', () => {
])
})

it('should return the expected data from the comparison multi image test fixture', () => {
expect(
collectPaths([], multiImagePlotsDiffFixture, REVISIONS)
).toStrictEqual([
{
hasChildren: false,
parentPath: 'plots',
path: join('plots', 'acc.png'),
revisions: new Set(REVISIONS),
type: new Set(['comparison'])
},
{
hasChildren: true,
parentPath: undefined,
path: 'plots',
revisions: new Set(REVISIONS)
},
{
hasChildren: false,
parentPath: 'plots',
path: join('plots', 'heatmap.png'),
revisions: new Set(REVISIONS),
type: new Set(['comparison'])
},
{
hasChildren: false,
parentPath: 'plots',
path: join('plots', 'loss.png'),
revisions: new Set(REVISIONS),
type: new Set(['comparison'])
},
{
hasChildren: false,
parentPath: 'plots',
path: join('plots', 'image'),
revisions: new Set(REVISIONS),
type: new Set(['comparison'])
}
])
})

it('should update the revision details when any revision is recollected', () => {
const [remainingPath] = Object.keys(plotsDiffFixture)
const collectedPaths = collectPaths([], plotsDiffFixture, REVISIONS)
Expand Down
31 changes: 27 additions & 4 deletions extension/src/plots/paths/collect.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import {
} from '../multiSource/constants'
import { MultiSourceEncoding } from '../multiSource/collect'
import { truncate } from '../../util/string'
import { MULTI_IMAGE_PATH_REG } from '../../cli/dvc/constants'

export enum PathType {
COMPARISON = 'comparison',
Expand Down Expand Up @@ -58,8 +59,13 @@ const collectType = (plots: Plot[]) => {
const getType = (
data: PlotsData,
hasChildren: boolean,
path: string
path: string,
isMultiImgPlot?: boolean
): Set<PathType> | undefined => {
if (isMultiImgPlot) {
return new Set<PathType>([PathType.COMPARISON])
}

if (hasChildren) {
return
}
Expand Down Expand Up @@ -123,7 +129,8 @@ const collectOrderedPath = (
data: PlotsData,
revisions: Set<string>,
pathArray: string[],
idx: number
idx: number,
isMultiImgDir: boolean
): PlotPath[] => {
const path = getPath(pathArray, idx)

Expand All @@ -147,7 +154,9 @@ const collectOrderedPath = (
revisions
}

const type = getType(data, hasChildren, path)
const isMultiImgPlot = isMultiImgDir && idx === pathArray.length

const type = getType(data, hasChildren, path, isMultiImgPlot)
if (type) {
plotPath.type = type
}
Expand All @@ -167,9 +176,23 @@ const addRevisionsToPath = (
}

const pathArray = getPathArray(path)
const isMultiImg =
MULTI_IMAGE_PATH_REG.test(path) &&
!!getType(data, false, path)?.has(PathType.COMPARISON)

if (isMultiImg) {
pathArray.pop()
}

for (let reverseIdx = pathArray.length; reverseIdx > 0; reverseIdx--) {
acc = collectOrderedPath(acc, data, revisions, pathArray, reverseIdx)
acc = collectOrderedPath(
acc,
data,
revisions,
pathArray,
reverseIdx,
isMultiImg
)
}
return acc
}
Expand Down
26 changes: 26 additions & 0 deletions extension/src/plots/paths/model.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import { join } from 'path'
import { PathsModel } from './model'
import { PathType } from './collect'
import plotsDiffFixture from '../../test/fixtures/plotsDiff/output'
import plotsDiffMultiImgFixture from '../../test/fixtures/plotsDiff/output/multiImage'
import { buildMockMemento } from '../../test/util'
import { PlotsType, TemplatePlotGroup } from '../webview/contract'
import { EXPERIMENT_WORKSPACE_ID } from '../../cli/dvc/contract'
Expand Down Expand Up @@ -354,6 +355,31 @@ describe('PathsModel', () => {
expect(model.getComparisonPaths()).toStrictEqual(newOrder)
})

it('should group multi comparison plot path directories', () => {
const model = new PathsModel(
mockDvcRoot,
buildMockErrorsModel(),
buildMockMemento()
)
const currentOrder = [
join('plots', 'acc.png'),
join('plots', 'heatmap.png'),
join('plots', 'loss.png'),
join('plots', 'image')
]

model.transformAndSet(plotsDiffMultiImgFixture, REVISIONS)
model.setSelectedRevisions([EXPERIMENT_WORKSPACE_ID])

expect(model.getComparisonPaths()).toStrictEqual(currentOrder)

const newOrder = [join('plots', 'image'), ...currentOrder.slice(0, 3)]

model.setComparisonPathsOrder(newOrder)

expect(model.getComparisonPaths()).toStrictEqual(newOrder)
})

it('should return the expected children from the test fixture', () => {
const model = new PathsModel(
mockDvcRoot,
Expand Down
25 changes: 25 additions & 0 deletions extension/src/test/fixtures/plotsDiff/comparison/multiVscode.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import { getComparisonWebviewMessage } from '..'
import { Uri, ViewColumn, window } from 'vscode'
import { ViewKey } from '../../../../webview/constants'
import { basePlotsUrl } from '../../../util'

const webviewPanel = window.createWebviewPanel(
ViewKey.PLOTS,
'webview for asWebviewUri',
ViewColumn.Active,
{
enableScripts: true
}
)

const baseUrl = webviewPanel.webview
.asWebviewUri(Uri.file(basePlotsUrl))
.toString()

webviewPanel.dispose()

const uriJoin = (...segments: string[]) => segments.join('/')

const data = getComparisonWebviewMessage(baseUrl, uriJoin, true)

export default data
10 changes: 6 additions & 4 deletions extension/src/test/fixtures/plotsDiff/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,7 @@ const getMultiImageData = (
}[]
} = {}
for (let i = 0; i < 15; i++) {
const key = join('plots', 'image', `${i}.jpg`)
const key = joinFunc('plots', 'image', `${i}.jpg`)
const values = []
for (const revision of revisions) {
values.push({
Expand Down Expand Up @@ -496,6 +496,10 @@ export const getOutput = (baseUrl: string): PlotsOutput => ({

export const getMinimalOutput = (): PlotsOutput => ({ data: { ...basicVega } })

export const getMultiImgOutput = (baseUrl: string): PlotsOutput => ({
data: { ...getImageDataWithMultiImgs(baseUrl) }
})

export const getMultiSourceOutput = (): PlotsOutput => ({
...require('./multiSource').default
})
Expand Down Expand Up @@ -809,9 +813,7 @@ export const getComparisonWebviewMessage = (
? getImageDataWithMultiImgs(baseUrl, joinFunc)
: getImageData(baseUrl, joinFunc)
)) {
const multiImagePath = joinFunc('plots', 'image')
const isMulti = path.includes(multiImagePath)
const pathLabel = path
const pathLabel = path.includes('image') ? join('plots', 'image') : path
Copy link
Contributor Author

@julieg18 julieg18 Jul 30, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We had to use join here so the tests would pass on windows (we rebuild the image path when collecting the plots).


if (!plotAcc[pathLabel]) {
plotAcc[pathLabel] = {
Expand Down
6 changes: 6 additions & 0 deletions extension/src/test/fixtures/plotsDiff/output/multiImage.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
import { basePlotsUrl } from '../../../util'
import { getMultiImgOutput } from '..'

const data = getMultiImgOutput(basePlotsUrl)

export default data
18 changes: 18 additions & 0 deletions extension/src/test/suite/plots/index.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,11 @@ import gitLogFixture from '../../fixtures/expShow/base/gitLog'
import rowOrderFixture from '../../fixtures/expShow/base/rowOrder'
import customPlotsFixture from '../../fixtures/expShow/base/customPlots'
import plotsDiffFixture from '../../fixtures/plotsDiff/output'
import multiImagePlotsDiffFixture from '../../fixtures/plotsDiff/output/multiImage'
import multiSourcePlotsDiffFixture from '../../fixtures/plotsDiff/multiSource'
import templatePlotsFixture from '../../fixtures/plotsDiff/template'
import comparisonPlotsFixture from '../../fixtures/plotsDiff/comparison/vscode'
import comparisonPlotsMultiImgFixture from '../../fixtures/plotsDiff/comparison/multiVscode'
import plotsRevisionsFixture from '../../fixtures/plotsDiff/revisions'
import {
bypassProcessManagerDebounce,
Expand Down Expand Up @@ -968,6 +970,22 @@ suite('Plots Test Suite', () => {
}
}).timeout(WEBVIEW_TEST_TIMEOUT)

it('should send the correct data to the webview for multi image plots', async () => {
const { plots, messageSpy, mockPlotsDiff } = await buildPlots({
disposer: disposable,
plotsDiff: multiImagePlotsDiffFixture
})

const webview = await plots.showWebview()
await webview.isReady()

expect(mockPlotsDiff).to.be.called

const { comparison: comparisonData } = getFirstArgOfLastCall(messageSpy)

expect(comparisonData).to.deep.equal(comparisonPlotsMultiImgFixture)
}).timeout(WEBVIEW_TEST_TIMEOUT)

it('should handle a toggle experiment message from the webview', async () => {
const { plots, experiments } = await buildPlots({
disposer: disposable,
Expand Down
Loading