diff --git a/extension/src/cli/dvc/constants.ts b/extension/src/cli/dvc/constants.ts index bb48c418d9..d913cb44e2 100644 --- a/extension/src/cli/dvc/constants.ts +++ b/extension/src/cli/dvc/constants.ts @@ -18,6 +18,8 @@ export const EXP_RWLOCK_FILE = join(TEMP_EXP_DIR, 'rwlock.lock') export const DEFAULT_NUM_OF_COMMITS_TO_SHOW = 3 export const NUM_OF_COMMITS_TO_INCREASE = 2 +export const MULTI_IMAGE_PATH_REG = /[^/]+[/\\]\d+\.[a-z]+$/i + export enum Command { ADD = 'add', CHECKOUT = 'checkout', diff --git a/extension/src/cli/dvc/index.test.ts b/extension/src/cli/dvc/index.test.ts index 55272fc296..d38f95f0d7 100644 --- a/extension/src/cli/dvc/index.test.ts +++ b/extension/src/cli/dvc/index.test.ts @@ -1,7 +1,8 @@ +import { join } from 'path' import { EventEmitter } from 'vscode' import { Disposable, Disposer } from '@hediet/std/disposable' import { DvcCli } from '.' -import { Command } from './constants' +import { Command, MULTI_IMAGE_PATH_REG } from './constants' import { CliResult, CliStarted, typeCheckCommands } from '..' import { getProcessEnv } from '../../env' import { createProcess } from '../../process/execution' @@ -52,6 +53,53 @@ describe('typeCheckCommands', () => { }) }) +describe('Comparison Multi Image Regex', () => { + it('should match a nested image group directory', () => { + expect( + MULTI_IMAGE_PATH_REG.test( + join( + 'extremely', + 'super', + 'super', + 'super', + 'nested', + 'image', + '768.svg' + ) + ) + ).toBe(true) + }) + + it('should match directories with spaces or special characters', () => { + expect(MULTI_IMAGE_PATH_REG.test(join('mis classified', '5.png'))).toBe( + true + ) + + expect(MULTI_IMAGE_PATH_REG.test(join('misclassified#^', '5.png'))).toBe( + true + ) + }) + + it('should match different types of images', () => { + const imageFormats = ['svg', 'png', 'jpg', 'jpeg'] + for (const format of imageFormats) { + expect( + MULTI_IMAGE_PATH_REG.test(join('misclassified', `5.${format}`)) + ).toBe(true) + } + }) + + it('should not match files that include none digits or do not have a file extension', () => { + expect(MULTI_IMAGE_PATH_REG.test(join('misclassified', 'five.png'))).toBe( + false + ) + expect(MULTI_IMAGE_PATH_REG.test(join('misclassified', '5 4.png'))).toBe( + false + ) + expect(MULTI_IMAGE_PATH_REG.test(join('misclassified', '5'))).toBe(false) + }) +}) + describe('executeDvcProcess', () => { it('should pass the correct details to the underlying process given no path to the cli or python binary path', async () => { const existingPath = joinEnvPath( diff --git a/extension/src/fileSystem/util.ts b/extension/src/fileSystem/util.ts index 2ff2d3b132..384422476d 100644 --- a/extension/src/fileSystem/util.ts +++ b/extension/src/fileSystem/util.ts @@ -1,4 +1,4 @@ -import { sep } from 'path' +import { sep, parse } from 'path' export const getPathArray = (path: string): string[] => path.split(sep) @@ -18,3 +18,5 @@ export const getParent = (pathArray: string[], idx: number) => { export const removeTrailingSlash = (path: string): string => path.endsWith(sep) ? path.slice(0, -1) : path + +export const getFileNameWithoutExt = (path: string) => parse(path).name diff --git a/extension/src/plots/model/collect.test.ts b/extension/src/plots/model/collect.test.ts index 2773c1bce1..0ac8afdb8b 100644 --- a/extension/src/plots/model/collect.test.ts +++ b/extension/src/plots/model/collect.test.ts @@ -84,7 +84,8 @@ describe('collectData', () => { expect(Object.keys(comparisonData.main)).toStrictEqual([ join('plots', 'acc.png'), heatmapPlot, - join('plots', 'loss.png') + join('plots', 'loss.png'), + join('plots', 'image') ]) const testBranchHeatmap = comparisonData['test-branch'][heatmapPlot] diff --git a/extension/src/plots/model/collect.ts b/extension/src/plots/model/collect.ts index 7bbbbd59a7..67222bdfd5 100644 --- a/extension/src/plots/model/collect.ts +++ b/extension/src/plots/model/collect.ts @@ -36,6 +36,12 @@ import { import { StrokeDashEncoding } from '../multiSource/constants' import { exists } from '../../fileSystem' import { hasKey } from '../../util/object' +import { MULTI_IMAGE_PATH_REG } from '../../cli/dvc/constants' +import { + getFileNameWithoutExt, + getParent, + getPathArray +} from '../../fileSystem/util' export const getCustomPlotId = (metric: string, param: string) => `custom-${metric}-${param}` @@ -128,19 +134,31 @@ export type RevisionData = { [label: string]: RevisionPathData } +type ComparisonDataImgPlot = ImagePlot & { ind?: number } + export type ComparisonData = { [label: string]: { - [path: string]: ImagePlot[] + [path: string]: ComparisonDataImgPlot[] } } +const getMultiImagePath = (path: string) => + getParent(getPathArray(path), 0) as string + +const getMultiImageInd = (path: string) => { + const fileName = getFileNameWithoutExt(path) + return Number(fileName) +} + const collectImageData = ( acc: ComparisonData, path: string, plot: ImagePlot ) => { - const pathLabel = path + const isMultiImgPlot = MULTI_IMAGE_PATH_REG.test(path) + const pathLabel = isMultiImgPlot ? getMultiImagePath(path) : path const id = plot.revisions?.[0] + if (!id) { return } @@ -153,7 +171,13 @@ const collectImageData = ( acc[id][pathLabel] = [] } - acc[id][pathLabel].push(plot) + const imgPlot: ComparisonDataImgPlot = { ...plot } + + if (isMultiImgPlot) { + imgPlot.ind = getMultiImageInd(path) + } + + acc[id][pathLabel].push(imgPlot) } const collectDatapoints = ( @@ -209,6 +233,16 @@ const collectPathData = (acc: DataAccumulator, path: string, plots: Plot[]) => { } } +const sortComparisonImgPaths = (acc: DataAccumulator) => { + for (const [label, paths] of Object.entries(acc.comparisonData)) { + for (const path of Object.keys(paths)) { + acc.comparisonData[label][path].sort( + (img1, img2) => (img1.ind || 0) - (img2.ind || 0) + ) + } + } +} + export const collectData = (output: PlotsOutput): DataAccumulator => { const { data } = output const acc = { @@ -220,6 +254,8 @@ export const collectData = (output: PlotsOutput): DataAccumulator => { collectPathData(acc, path, plots) } + sortComparisonImgPaths(acc) + return acc } @@ -248,7 +284,7 @@ const collectSelectedPathComparisonPlots = ({ } for (const id of selectedRevisionIds) { - const imgs = comparisonData[id][path] + const imgs = comparisonData[id]?.[path] pathRevisions.revisions[id] = { id, imgs: imgs diff --git a/extension/src/plots/paths/collect.test.ts b/extension/src/plots/paths/collect.test.ts index 21cdfed603..d90320d68d 100644 --- a/extension/src/plots/paths/collect.test.ts +++ b/extension/src/plots/paths/collect.test.ts @@ -45,6 +45,13 @@ describe('collectPaths', () => { revisions: new Set(REVISIONS), type: new Set(['comparison']) }, + { + hasChildren: false, + parentPath: 'plots', + path: join('plots', 'image'), + revisions: new Set(REVISIONS), + type: new Set(['comparison']) + }, { hasChildren: false, parentPath: 'logs', diff --git a/extension/src/plots/paths/collect.ts b/extension/src/plots/paths/collect.ts index e75a46bd8a..6c3c6349e5 100644 --- a/extension/src/plots/paths/collect.ts +++ b/extension/src/plots/paths/collect.ts @@ -23,6 +23,7 @@ import { } from '../multiSource/constants' import { MultiSourceEncoding } from '../multiSource/collect' import { truncate } from '../../util/string' +import { MULTI_IMAGE_PATH_REG } from '../../cli/dvc/constants' export enum PathType { COMPARISON = 'comparison', @@ -58,8 +59,13 @@ const collectType = (plots: Plot[]) => { const getType = ( data: PlotsData, hasChildren: boolean, - path: string + path: string, + isMultiImgPlot?: boolean ): Set | undefined => { + if (isMultiImgPlot) { + return new Set([PathType.COMPARISON]) + } + if (hasChildren) { return } @@ -123,7 +129,8 @@ const collectOrderedPath = ( data: PlotsData, revisions: Set, pathArray: string[], - idx: number + idx: number, + isMultiImgDir: boolean ): PlotPath[] => { const path = getPath(pathArray, idx) @@ -147,7 +154,10 @@ const collectOrderedPath = ( revisions } - const type = getType(data, hasChildren, path) + const isPathLeaf = idx === pathArray.length + const isMultiImgPlot = isMultiImgDir && isPathLeaf + + const type = getType(data, hasChildren, path, isMultiImgPlot) if (type) { plotPath.type = type } @@ -167,9 +177,23 @@ const addRevisionsToPath = ( } const pathArray = getPathArray(path) + const isMultiImg = + MULTI_IMAGE_PATH_REG.test(path) && + !!getType(data, false, path)?.has(PathType.COMPARISON) + + if (isMultiImg) { + pathArray.pop() + } for (let reverseIdx = pathArray.length; reverseIdx > 0; reverseIdx--) { - acc = collectOrderedPath(acc, data, revisions, pathArray, reverseIdx) + acc = collectOrderedPath( + acc, + data, + revisions, + pathArray, + reverseIdx, + isMultiImg + ) } return acc } diff --git a/extension/src/plots/paths/model.test.ts b/extension/src/plots/paths/model.test.ts index fb83d660a7..caf5468886 100644 --- a/extension/src/plots/paths/model.test.ts +++ b/extension/src/plots/paths/model.test.ts @@ -59,6 +59,14 @@ describe('PathsModel', () => { selected: true, type: comparisonType }, + { + hasChildren: false, + parentPath: 'plots', + path: join('plots', 'image'), + revisions: new Set(REVISIONS), + selected: true, + type: comparisonType + }, { hasChildren: false, parentPath: 'logs', @@ -340,13 +348,15 @@ describe('PathsModel', () => { expect(model.getComparisonPaths()).toStrictEqual([ join('plots', 'acc.png'), join('plots', 'heatmap.png'), - join('plots', 'loss.png') + join('plots', 'loss.png'), + join('plots', 'image') ]) const newOrder = [ join('plots', 'heatmap.png'), join('plots', 'acc.png'), - join('plots', 'loss.png') + join('plots', 'loss.png'), + join('plots', 'image') ] model.setComparisonPathsOrder(newOrder) @@ -380,7 +390,7 @@ describe('PathsModel', () => { tooltip: undefined }, { - descendantStatuses: [2, 2, 2], + descendantStatuses: [2, 2, 2, 2], hasChildren: true, parentPath: undefined, path: 'plots', diff --git a/extension/src/test/fixtures/plotsDiff/comparison/multi.ts b/extension/src/test/fixtures/plotsDiff/comparison/multi.ts deleted file mode 100644 index 8c8a2bd1ab..0000000000 --- a/extension/src/test/fixtures/plotsDiff/comparison/multi.ts +++ /dev/null @@ -1,5 +0,0 @@ -import { getComparisonWebviewMessage } from '..' - -const data = getComparisonWebviewMessage('.', undefined, true) - -export default data diff --git a/extension/src/test/fixtures/plotsDiff/index.ts b/extension/src/test/fixtures/plotsDiff/index.ts index 476cac6d85..48041b3c0e 100644 --- a/extension/src/test/fixtures/plotsDiff/index.ts +++ b/extension/src/test/fixtures/plotsDiff/index.ts @@ -377,7 +377,7 @@ const getMultiImageData = ( }[] } = {} for (let i = 0; i < 15; i++) { - const key = join('plots', 'image', `${i}.jpg`) + const key = joinFunc('plots', 'image', `${i}.jpg`) const values = [] for (const revision of revisions) { values.push({ @@ -472,11 +472,7 @@ const getImageData = (baseUrl: string, joinFunc = join) => ({ revisions: ['exp-83425'], url: joinFunc(baseUrl, '1ba7bcd_plots_loss.png') } - ] -}) - -const getImageDataWithMultiImgs = (baseUrl: string, joinFunc = join) => ({ - ...getImageData(baseUrl, joinFunc), + ], ...getMultiImageData(baseUrl, joinFunc, [ EXPERIMENT_WORKSPACE_ID, 'main', @@ -797,21 +793,14 @@ export const MOCK_IMAGE_MTIME = 946684800000 export const getComparisonWebviewMessage = ( baseUrl: string, - joinFunc: (...args: string[]) => string = join, - addMulti?: boolean + joinFunc: (...args: string[]) => string = join ): PlotsComparisonData => { const plotAcc: { [path: string]: { path: string; revisions: ComparisonRevisionData } } = {} - for (const [path, plots] of Object.entries( - addMulti - ? getImageDataWithMultiImgs(baseUrl, joinFunc) - : getImageData(baseUrl, joinFunc) - )) { - const multiImagePath = joinFunc('plots', 'image') - const isMulti = path.includes(multiImagePath) - const pathLabel = path + for (const [path, plots] of Object.entries(getImageData(baseUrl, joinFunc))) { + const pathLabel = path.includes('image') ? join('plots', 'image') : path if (!plotAcc[pathLabel]) { plotAcc[pathLabel] = { diff --git a/extension/src/test/suite/plots/index.test.ts b/extension/src/test/suite/plots/index.test.ts index 826e7f2fdf..d4b5c6b135 100644 --- a/extension/src/test/suite/plots/index.test.ts +++ b/extension/src/test/suite/plots/index.test.ts @@ -318,7 +318,8 @@ suite('Plots Test Suite', () => { const mockComparisonPathsOrder = [ join('plots', 'acc.png'), join('plots', 'heatmap.png'), - join('plots', 'loss.png') + join('plots', 'loss.png'), + join('plots', 'image') ] messageSpy.resetHistory() diff --git a/webview/src/plots/components/App.test.tsx b/webview/src/plots/components/App.test.tsx index 0c97dc951d..63ae0df46b 100644 --- a/webview/src/plots/components/App.test.tsx +++ b/webview/src/plots/components/App.test.tsx @@ -275,6 +275,70 @@ describe('App', () => { expect(loading).toHaveLength(3) }) + it('should render loading section states with multi image plots when given a single revision which has not been fetched', async () => { + renderAppWithOptionalData({ + comparison: { + height: DEFAULT_PLOT_HEIGHT, + plots: [ + { + path: 'training/plots/images/image', + revisions: { + ad2b5ec: { + id: 'ad2b5ec', + imgs: [ + { + errors: undefined, + loading: true, + url: undefined + }, + { + errors: undefined, + loading: true, + url: undefined + }, + { + errors: undefined, + loading: true, + url: undefined + } + ] + } + } + } + ], + revisions: [ + { + description: '[exp-a270a]', + displayColor: '#945dd6', + fetched: false, + id: 'ad2b5ec', + label: 'ad2b5ec', + summaryColumns: [] + } + ], + width: DEFAULT_NB_ITEMS_PER_ROW + }, + custom: null, + hasPlots: true, + hasUnselectedPlots: false, + sectionCollapsed: DEFAULT_SECTION_COLLAPSED, + selectedRevisions: [ + { + description: '[exp-a270a]', + displayColor: '#945dd6', + fetched: false, + id: 'ad2b5ec', + label: 'ad2b5ec', + summaryColumns: [] + } + ], + template: null + }) + const loading = await screen.findAllByText('Loading...') + + expect(loading).toHaveLength(3) + }) + it('should render only get started (buttons: add plots, add experiments, add custom plots) when there are some selected exps, all unselected plots, and no custom plots', async () => { renderAppWithOptionalData({ hasPlots: true, @@ -1384,6 +1448,22 @@ describe('App', () => { }) }) + it('should send a message with the plot path when a comparison table multi img plot is zoomed', () => { + renderAppWithOptionalData({ + comparison: comparisonTableFixture + }) + + const plotWrapper = screen.getAllByTestId('multi-image-cell')[0] + const plot = within(plotWrapper).getByTestId('image-plot-button') + + fireEvent.click(plot) + + expect(mockPostMessage).toHaveBeenCalledWith({ + payload: comparisonTableFixture.plots[3].revisions.workspace.imgs[0].url, + type: MessageFromWebviewType.ZOOM_PLOT + }) + }) + it('should open a modal with the plot zoomed in when clicking a custom plot', () => { renderAppWithOptionalData({ custom: customPlotsFixture @@ -1570,6 +1650,58 @@ describe('App', () => { expect(multiViewPlot).toHaveStyle('--scale: 2') }) + describe('Comparison Multi Image Plots', () => { + it('should render cells with sliders', () => { + renderAppWithOptionalData({ + comparison: comparisonTableFixture + }) + + const multiImgPlot = screen.getAllByTestId('multi-image-cell')[0] + const slider = within(multiImgPlot).getByRole('slider') + + expect(slider).toBeInTheDocument() + }) + + it('should update the cell image when the slider changes', () => { + renderAppWithOptionalData({ + comparison: comparisonTableFixture + }) + + const workspaceImgs = + comparisonTableFixture.plots[3].revisions.workspace.imgs + + const multiImgPlots = screen.getAllByTestId('multi-image-cell') + const slider = within(multiImgPlots[0]).getByRole('slider') + const workspaceImgEl = within(multiImgPlots[0]).getByRole('img') + + expect(workspaceImgEl).toHaveAttribute('src', workspaceImgs[0].url) + + fireEvent.change(slider, { target: { value: 3 } }) + + expect(workspaceImgEl).toHaveAttribute('src', workspaceImgs[3].url) + }) + + it('should disable the multi img row from drag and drop when hovering over a img slider', () => { + renderAppWithOptionalData({ + comparison: comparisonTableFixture + }) + + const multiImgRow = screen.getAllByTestId('comparison-table-body')[3] + const multiImgPlots = screen.getAllByTestId('multi-image-cell') + const slider = within(multiImgPlots[0]).getByRole('slider') + + expect(multiImgRow.draggable).toBe(true) + + fireEvent.mouseEnter(slider) + + expect(multiImgRow.draggable).toBe(false) + + fireEvent.mouseLeave(slider) + + expect(multiImgRow.draggable).toBe(true) + }) + }) + describe('Virtualization', () => { const createCustomPlots = (nbOfPlots: number): CustomPlotsData => { const plots = [] diff --git a/webview/src/plots/components/comparisonTable/ComparisonTable.test.tsx b/webview/src/plots/components/comparisonTable/ComparisonTable.test.tsx index 5739ddc2e1..f037e564ed 100644 --- a/webview/src/plots/components/comparisonTable/ComparisonTable.test.tsx +++ b/webview/src/plots/components/comparisonTable/ComparisonTable.test.tsx @@ -293,8 +293,8 @@ describe('ComparisonTable', () => { description: undefined, displayColor: '#f56565', fetched: true, - id: 'noData', - label: revisionWithNoData, + id: revisionWithNoData, + label: 'noData', summaryColumns: [] } ] diff --git a/webview/src/plots/components/comparisonTable/ComparisonTableRow.tsx b/webview/src/plots/components/comparisonTable/ComparisonTableRow.tsx index f92be3e0b2..a2cf194fc0 100644 --- a/webview/src/plots/components/comparisonTable/ComparisonTableRow.tsx +++ b/webview/src/plots/components/comparisonTable/ComparisonTableRow.tsx @@ -4,6 +4,7 @@ import cx from 'classnames' import { useSelector } from 'react-redux' import styles from './styles.module.scss' import { ComparisonTableCell } from './cell/ComparisonTableCell' +import { ComparisonTableMultiCell } from './cell/ComparisonTableMultiCell' import { Icon } from '../../../shared/components/Icon' import { ChevronDown, ChevronRight } from '../../../shared/components/icons' import { PlotsState } from '../../store' @@ -76,7 +77,11 @@ export const ComparisonTableRow: React.FC = ({ data-testid="row-images" className={cx(styles.cell, { [styles.cellHidden]: !isShown })} > - + {plot.imgs.length > 1 ? ( + + ) : ( + + )} ) diff --git a/webview/src/plots/components/comparisonTable/ComparisonTableRows.tsx b/webview/src/plots/components/comparisonTable/ComparisonTableRows.tsx index aebc2bee21..93960b9e3b 100644 --- a/webview/src/plots/components/comparisonTable/ComparisonTableRows.tsx +++ b/webview/src/plots/components/comparisonTable/ComparisonTableRows.tsx @@ -1,12 +1,13 @@ import { ComparisonPlots } from 'dvc/src/plots/webview/contract' import React, { createRef, useEffect, useState } from 'react' -import { useDispatch } from 'react-redux' +import { useDispatch, useSelector } from 'react-redux' import { ComparisonTableColumn } from './ComparisonTableHead' import { ComparisonTableRow } from './ComparisonTableRow' import { changeRowHeight, DEFAULT_ROW_HEIGHT } from './comparisonTableSlice' import { RowDropTarget } from './RowDropTarget' import { DragDropContainer } from '../../../shared/components/dragDrop/DragDropContainer' import { reorderComparisonRows } from '../../util/messages' +import { PlotsState } from '../../store' interface ComparisonTableRowsProps { plots: ComparisonPlots @@ -22,6 +23,9 @@ export const ComparisionTableRows: React.FC = ({ const [rowsOrder, setRowsOrder] = useState([]) const dispatch = useDispatch() const firstRowRef = createRef() + const disabledDragPlotIds = useSelector( + (state: PlotsState) => state.comparison.disabledDragPlotIds + ) useEffect(() => { setRowsOrder(plots.map(({ path }) => path)) @@ -40,12 +44,19 @@ export const ComparisionTableRows: React.FC = ({ } const revs = plot.revisions return ( - + ({ - ...revs[column.id], - id: column.id + id: column.id, + imgs: revs[column.id]?.imgs || [ + { errors: undefined, loading: false, url: undefined } + ] }))} nbColumns={columns.length} pinnedColumn={pinnedColumn} @@ -67,6 +78,7 @@ export const ComparisionTableRows: React.FC = ({ group="comparison-table" dropTarget={} onLayoutChange={onLayoutChange} + disabledDropIds={disabledDragPlotIds} vertical /> ) diff --git a/webview/src/plots/components/comparisonTable/cell/ComparisonTableCell.tsx b/webview/src/plots/components/comparisonTable/cell/ComparisonTableCell.tsx index f9c8e1fde9..53bee7782e 100644 --- a/webview/src/plots/components/comparisonTable/cell/ComparisonTableCell.tsx +++ b/webview/src/plots/components/comparisonTable/cell/ComparisonTableCell.tsx @@ -8,10 +8,10 @@ import { zoomPlot } from '../../../util/messages' export const ComparisonTableCell: React.FC<{ path: string plot: ComparisonPlot -}> = ({ path, plot }) => { - const plotImg = plot.imgs - ? plot.imgs[0] - : { errors: undefined, loading: false, url: undefined } + imgAlt?: string +}> = ({ path, plot, imgAlt }) => { + const plotImg = plot.imgs[0] + const loading = plotImg.loading const missing = !loading && !plotImg.url @@ -33,7 +33,7 @@ export const ComparisonTableCell: React.FC<{ className={styles.image} draggable={false} src={plotImg.url} - alt={`Plot of ${path} (${plot.id})`} + alt={imgAlt || `Plot of ${path} (${plot.id})`} /> ) diff --git a/webview/src/plots/components/comparisonTable/cell/ComparisonTableMultiCell.tsx b/webview/src/plots/components/comparisonTable/cell/ComparisonTableMultiCell.tsx new file mode 100644 index 0000000000..5700b39505 --- /dev/null +++ b/webview/src/plots/components/comparisonTable/cell/ComparisonTableMultiCell.tsx @@ -0,0 +1,50 @@ +import React, { useCallback, useState } from 'react' +import { useDispatch } from 'react-redux' +import { ComparisonPlot } from 'dvc/src/plots/webview/contract' +import { ComparisonTableCell } from './ComparisonTableCell' +import styles from '../styles.module.scss' +import { changeDisabledDragIds } from '../comparisonTableSlice' + +export const ComparisonTableMultiCell: React.FC<{ + path: string + plot: ComparisonPlot +}> = ({ path, plot }) => { + const [currentStep, setCurrentStep] = useState(0) + const dispatch = useDispatch() + + const addDisabled = useCallback(() => { + dispatch(changeDisabledDragIds([path])) + }, [dispatch, path]) + + const removeDisabled = useCallback(() => { + dispatch(changeDisabledDragIds([])) + }, [dispatch]) + + return ( +
+ +
+ + { + setCurrentStep(Number(event.target.value)) + }} + /> +

{currentStep}

+
+
+ ) +} diff --git a/webview/src/plots/components/comparisonTable/comparisonTableSlice.ts b/webview/src/plots/components/comparisonTable/comparisonTableSlice.ts index d9f9a3a1ed..81029d8710 100644 --- a/webview/src/plots/components/comparisonTable/comparisonTableSlice.ts +++ b/webview/src/plots/components/comparisonTable/comparisonTableSlice.ts @@ -11,11 +11,13 @@ export interface ComparisonTableState extends PlotsComparisonData { isCollapsed: boolean hasData: boolean rowHeight: number + disabledDragPlotIds: string[] } export const DEFAULT_ROW_HEIGHT = 200 export const comparisonTableInitialState: ComparisonTableState = { + disabledDragPlotIds: [], hasData: false, height: DEFAULT_HEIGHT[PlotsSection.COMPARISON_TABLE], isCollapsed: DEFAULT_SECTION_COLLAPSED[PlotsSection.COMPARISON_TABLE], @@ -30,6 +32,9 @@ export const comparisonTableSlice = createSlice({ initialState: comparisonTableInitialState, name: 'comparison', reducers: { + changeDisabledDragIds: (state, action: PayloadAction) => { + state.disabledDragPlotIds = action.payload + }, changeRowHeight: (state, action: PayloadAction) => { state.rowHeight = action.payload }, @@ -55,7 +60,12 @@ export const comparisonTableSlice = createSlice({ } }) -export const { update, setCollapsed, changeSize, changeRowHeight } = - comparisonTableSlice.actions +export const { + update, + setCollapsed, + changeSize, + changeDisabledDragIds, + changeRowHeight +} = comparisonTableSlice.actions export default comparisonTableSlice.reducer diff --git a/webview/src/plots/components/comparisonTable/styles.module.scss b/webview/src/plots/components/comparisonTable/styles.module.scss index 9784d666aa..12ef4ac796 100644 --- a/webview/src/plots/components/comparisonTable/styles.module.scss +++ b/webview/src/plots/components/comparisonTable/styles.module.scss @@ -88,25 +88,6 @@ $gap: 4px; transform: rotate(0deg); } -.imageWrapper { - width: 100%; - display: block; - padding: 0; - border: 0; -} - -.noImage { - background-color: $bg-color; - border-style: solid; - border-width: thin; - border-color: $bg-color; -} - -.noImageContent { - padding-top: 25%; - padding-bottom: 25%; -} - .rowToggler { border: none; background: none; @@ -170,12 +151,46 @@ $gap: 4px; max-height: 0; } -.cell img { +.image { width: 100%; height: auto; vertical-align: middle; } +.imageWrapper { + width: 100%; + display: block; + padding: 0; + border: 0; +} + +.noImageContent { + padding-top: 25%; + padding-bottom: 25%; + display: flex; + flex-direction: column; + align-items: center; + justify-content: center; +} + +.multiImageWrapper { + .image, + .noImageContent { + height: 380px; + object-fit: contain; + } +} + +.multiImageSlider { + display: flex; + height: 40px; + align-items: center; + justify-content: center; + color: $fg-color; + background-color: $bg-color; + box-shadow: inset 40px 40px $fg-transparency-1; +} + .experimentName { color: $meta-cell-color; } diff --git a/webview/src/stories/ComparisonTable.stories.tsx b/webview/src/stories/ComparisonTable.stories.tsx index 172607e779..cef1732ad2 100644 --- a/webview/src/stories/ComparisonTable.stories.tsx +++ b/webview/src/stories/ComparisonTable.stories.tsx @@ -5,13 +5,13 @@ import { userEvent, within } from '@storybook/testing-library' import React from 'react' import { Provider, useDispatch } from 'react-redux' import { + ComparisonPlotImg, ComparisonRevisionData, DEFAULT_NB_ITEMS_PER_ROW, DEFAULT_PLOT_HEIGHT, PlotsComparisonData } from 'dvc/src/plots/webview/contract' import comparisonTableFixture from 'dvc/src/test/fixtures/plotsDiff/comparison' -import comparisonTableMultiFixture from 'dvc/src/test/fixtures/plotsDiff/comparison/multi' import { EXPERIMENT_WORKSPACE_ID } from 'dvc/src/cli/dvc/contract' import { DISABLE_CHROMATIC_SNAPSHOTS } from './util' import { ComparisonTable } from '../plots/components/comparisonTable/ComparisonTable' @@ -83,18 +83,25 @@ const removeImages = ( ['main', '4fb124a'].includes(id)) || id === EXPERIMENT_WORKSPACE_ID ) { + const isMulti = revisionsData[id].imgs.length > 1 filteredRevisionData[id] = { id, - imgs: [ - { - errors: - id === 'main' - ? [`FileNotFoundError: ${path} not found.`] - : undefined, - loading: false, - url: undefined - } - ] + imgs: isMulti + ? (Array.from({ length: revisionsData[id].imgs.length }).fill({ + errors: undefined, + loading: false, + url: undefined + }) as ComparisonPlotImg[]) + : [ + { + errors: + id === 'main' + ? [`FileNotFoundError: ${path} not found.`] + : undefined, + loading: false, + url: undefined + } + ] } continue } @@ -103,11 +110,6 @@ const removeImages = ( return filteredRevisionData } -export const WithMultiImages = Template.bind({}) -WithMultiImages.args = { - plots: comparisonTableMultiFixture.plots -} - export const WithMissingData = Template.bind({}) WithMissingData.args = { plots: comparisonTableFixture.plots.map(({ path, revisions }) => ({