diff --git a/extension/package.json b/extension/package.json index ea7a7ef21a..8aca9e2ed2 100644 --- a/extension/package.json +++ b/extension/package.json @@ -1163,12 +1163,12 @@ { "command": "dvc.views.experiments.applyExperiment", "group": "inline@1", - "when": "view == dvc.views.experimentsTree && dvc.commands.available && viewItem == experiment && !dvc.experiment.running.workspace" + "when": "view == dvc.views.experimentsTree && dvc.commands.available && viewItem =~ /^(experiment|commit)$/ && !dvc.experiment.running.workspace" }, { "command": "dvc.views.experiments.branchExperiment", "group": "inline@2", - "when": "view == dvc.views.experimentsTree && dvc.commands.available && viewItem == experiment && !dvc.experiment.running.workspace" + "when": "view == dvc.views.experimentsTree && dvc.commands.available && viewItem =~ /^(experiment|commit)$/ && !dvc.experiment.running.workspace" }, { "command": "dvc.views.experimentsTree.removeExperiment", diff --git a/extension/src/experiments/index.ts b/extension/src/experiments/index.ts index eabc78efe1..6d5c943383 100644 --- a/extension/src/experiments/index.ts +++ b/extension/src/experiments/index.ts @@ -371,9 +371,9 @@ export class Experiments extends BaseRepository { return this.notifyChanged() } - public pickExperiment() { + public pickCommitOrExperiment() { return pickExperiment( - this.experiments.getExperiments(), + this.experiments.getCommitsAndExperiments(), this.getFirstThreeColumnOrder() ) } diff --git a/extension/src/experiments/model/collect.ts b/extension/src/experiments/model/collect.ts index 94985d1186..16155ed90b 100644 --- a/extension/src/experiments/model/collect.ts +++ b/extension/src/experiments/model/collect.ts @@ -6,7 +6,12 @@ import { } from 'vscode' import { ExperimentType } from '.' import { extractColumns } from '../columns/extract' -import { Experiment, CommitData, RunningExperiment } from '../webview/contract' +import { + Experiment, + CommitData, + RunningExperiment, + isQueued +} from '../webview/contract' import { EXPERIMENT_WORKSPACE_ID, ExperimentStatus, @@ -372,3 +377,28 @@ export const collectExperimentType = ( return acc } + +const collectExperimentsAndCommit = ( + acc: Experiment[], + commit: Experiment, + experiments: Experiment[] = [] +): void => { + acc.push(commit) + for (const experiment of experiments) { + if (isQueued(experiment.status)) { + continue + } + acc.push(experiment) + } +} + +export const collectOrderedCommitsAndExperiments = ( + commits: Experiment[], + getExperimentsByCommit: (commit: Experiment) => Experiment[] | undefined +): Experiment[] => { + const acc: Experiment[] = [] + for (const commit of commits) { + collectExperimentsAndCommit(acc, commit, getExperimentsByCommit(commit)) + } + return acc +} diff --git a/extension/src/experiments/model/index.ts b/extension/src/experiments/model/index.ts index 708759fc3c..26a40e35ba 100644 --- a/extension/src/experiments/model/index.ts +++ b/extension/src/experiments/model/index.ts @@ -1,7 +1,10 @@ import { Memento } from 'vscode' import { SortDefinition, sortExperiments } from './sortBy' import { FilterDefinition, filterExperiment, getFilterId } from './filterBy' -import { collectExperiments } from './collect' +import { + collectExperiments, + collectOrderedCommitsAndExperiments +} from './collect' import { collectColoredStatus, collectFinishedRunningExperiments, @@ -319,6 +322,12 @@ export class ExperimentsModel extends ModelWithPersistence { }) } + public getCommitsAndExperiments() { + return collectOrderedCommitsAndExperiments(this.commits, commit => + this.getExperimentsByCommit(commit) + ) + } + public getExperimentsAndQueued() { return flattenMapValues(this.experimentsByCommit).map(experiment => this.addDetails(experiment) diff --git a/extension/src/experiments/workspace.test.ts b/extension/src/experiments/workspace.test.ts index 68c598e6de..2dc31732bf 100644 --- a/extension/src/experiments/workspace.test.ts +++ b/extension/src/experiments/workspace.test.ts @@ -33,7 +33,7 @@ const mockedQuickPickOne = jest.mocked(quickPickOne) const mockedQuickPickManyValues = jest.mocked(quickPickManyValues) const mockedQuickPickOneOrInput = jest.mocked(quickPickOneOrInput) const mockedGetValidInput = jest.mocked(getValidInput) -const mockedPickExperiment = jest.fn() +const mockedPickCommitOrExperiment = jest.fn() const mockedGetInput = jest.mocked(getInput) const mockedRun = jest.fn() const mockedExpFunc = jest.fn() @@ -91,7 +91,7 @@ describe('Experiments', () => { { '/my/dvc/root': { getDvcRoot: () => mockedDvcRoot, - pickExperiment: mockedPickExperiment, + pickCommitOrExperiment: mockedPickCommitOrExperiment, showWebview: mockedShowWebview } as unknown as Experiments, '/my/fun/dvc/root': { @@ -138,12 +138,12 @@ describe('Experiments', () => { it('should call the correct function with the correct parameters if a project and experiment are picked', async () => { mockedQuickPickOne.mockResolvedValueOnce(mockedDvcRoot) mockedListStages.mockResolvedValueOnce('train') - mockedPickExperiment.mockResolvedValueOnce('a123456') + mockedPickCommitOrExperiment.mockResolvedValueOnce('a123456') await workspaceExperiments.getCwdAndExpNameThenRun(mockedCommandId) expect(mockedQuickPickOne).toHaveBeenCalledTimes(1) - expect(mockedPickExperiment).toHaveBeenCalledTimes(1) + expect(mockedPickCommitOrExperiment).toHaveBeenCalledTimes(1) expect(mockedExpFunc).toHaveBeenCalledTimes(1) expect(mockedExpFunc).toHaveBeenCalledWith(mockedDvcRoot, 'a123456') }) @@ -240,7 +240,7 @@ describe('Experiments', () => { it('should call the correct function with the correct parameters if a project and experiment are picked and an input provided', async () => { mockedQuickPickOne.mockResolvedValueOnce(mockedDvcRoot) mockedListStages.mockResolvedValueOnce('train') - mockedPickExperiment.mockResolvedValueOnce('a123456') + mockedPickCommitOrExperiment.mockResolvedValueOnce('a123456') mockedGetInput.mockResolvedValueOnce('abc123') await workspaceExperiments.getCwdExpNameAndInputThenRun( @@ -250,7 +250,7 @@ describe('Experiments', () => { ) expect(mockedQuickPickOne).toHaveBeenCalledTimes(1) - expect(mockedPickExperiment).toHaveBeenCalledTimes(1) + expect(mockedPickCommitOrExperiment).toHaveBeenCalledTimes(1) expect(mockedExpFunc).toHaveBeenCalledTimes(1) expect(mockedExpFunc).toHaveBeenCalledWith( mockedDvcRoot, @@ -276,7 +276,7 @@ describe('Experiments', () => { it('should not call the function if user input is not provided', async () => { mockedQuickPickOne.mockResolvedValueOnce(mockedDvcRoot) mockedListStages.mockResolvedValueOnce('train') - mockedPickExperiment.mockResolvedValueOnce({ + mockedPickCommitOrExperiment.mockResolvedValueOnce({ id: 'b456789', name: 'exp-456' }) @@ -296,7 +296,7 @@ describe('Experiments', () => { it('should check and ask for the creation of a pipeline stage before running the command', async () => { mockedQuickPickOne.mockResolvedValueOnce(mockedDvcRoot) mockedListStages.mockResolvedValueOnce('') - mockedPickExperiment.mockResolvedValueOnce({ + mockedPickCommitOrExperiment.mockResolvedValueOnce({ id: 'a123456', name: 'exp-123' }) diff --git a/extension/src/experiments/workspace.ts b/extension/src/experiments/workspace.ts index 72eec84ea0..6d4758e225 100644 --- a/extension/src/experiments/workspace.ts +++ b/extension/src/experiments/workspace.ts @@ -210,7 +210,9 @@ export class WorkspaceExperiments extends BaseWorkspaceWebviews< } public getCwdAndExpNameThenRun(commandId: CommandId) { - return this.pickExpThenRun(commandId, cwd => this.pickExperiment(cwd)) + return this.pickExpThenRun(commandId, cwd => + this.pickCommitOrExperiment(cwd) + ) } public async getCwdAndQuickPickThenRun( @@ -237,7 +239,7 @@ export class WorkspaceExperiments extends BaseWorkspaceWebviews< return } - const experimentId = await this.pickExperiment(cwd) + const experimentId = await this.pickCommitOrExperiment(cwd) if (!experimentId) { return @@ -545,7 +547,7 @@ export class WorkspaceExperiments extends BaseWorkspaceWebviews< return this.runCommand(commandId, cwd, experimentId) } - private pickExperiment(cwd: string) { - return this.getRepository(cwd).pickExperiment() + private pickCommitOrExperiment(cwd: string) { + return this.getRepository(cwd).pickCommitOrExperiment() } } diff --git a/extension/src/test/suite/experiments/workspace.test.ts b/extension/src/test/suite/experiments/workspace.test.ts index 47cb36997f..df51617bd5 100644 --- a/extension/src/test/suite/experiments/workspace.test.ts +++ b/extension/src/test/suite/experiments/workspace.test.ts @@ -542,7 +542,7 @@ suite('Workspace Experiments Test Suite', () => { }) describe('dvc.applyExperiment', () => { - it('should ask the user to pick an experiment and then apply that experiment to the workspace', async () => { + it('should ask the user to pick a commit or experiment and then apply it to the workspace', async () => { const selectedExperiment = 'test-branch' const { experiments } = buildExperiments(disposable) @@ -562,8 +562,17 @@ suite('Workspace Experiments Test Suite', () => { dvcDemoPath, selectedExperiment ) + expect(mockShowQuickPick).to.be.calledWith( [ + { + description: undefined, + detail: `Created:${formatDate( + '2020-11-21T19:58:22' + )}, loss:2.0488560, accuracy:0.34848332`, + label: 'main', + value: 'main' + }, { description: '[exp-e7a67]', detail: `Created:${formatDate( diff --git a/webview/src/experiments/components/App.test.tsx b/webview/src/experiments/components/App.test.tsx index db6c001ccd..8578763c9c 100644 --- a/webview/src/experiments/components/App.test.tsx +++ b/webview/src/experiments/components/App.test.tsx @@ -858,7 +858,7 @@ describe('App', () => { expect(itemLabels).toStrictEqual(['Modify and Run', 'Modify and Queue']) }) - it('should enable the correct options for the main row with checkpoints', () => { + it('should enable the correct options for a commit with checkpoints', () => { renderTableWithoutRunningExperiments() const target = screen.getByText('main') @@ -870,6 +870,8 @@ describe('App', () => { .filter(item => !item.className.includes('disabled')) .map(item => item.textContent) expect(itemLabels).toStrictEqual([ + 'Apply to Workspace', + 'Create new Branch', 'Modify and Run', 'Modify and Resume', 'Modify and Queue', @@ -877,6 +879,26 @@ describe('App', () => { ]) }) + it('should enable the correct options for a commit without checkpoints', () => { + renderTableWithoutRunningExperiments(false) + + const target = screen.getByText('main') + fireEvent.contextMenu(target, { bubbles: true }) + + advanceTimersByTime(100) + const menuitems = screen.getAllByRole('menuitem') + const itemLabels = menuitems + .filter(item => !item.className.includes('disabled')) + .map(item => item.textContent) + expect(itemLabels).toStrictEqual([ + 'Apply to Workspace', + 'Create new Branch', + 'Modify and Run', + 'Modify and Queue', + 'Star' + ]) + }) + it('should enable the correct options for an experiment that is not running and close on esc', () => { renderTableWithoutRunningExperiments() diff --git a/webview/src/experiments/components/table/body/RowContextMenu.tsx b/webview/src/experiments/components/table/body/RowContextMenu.tsx index 2d05fee05c..ae7e4cd751 100644 --- a/webview/src/experiments/components/table/body/RowContextMenu.tsx +++ b/webview/src/experiments/components/table/body/RowContextMenu.tsx @@ -229,11 +229,11 @@ const getSingleSelectMenuOptions = ( divider ) - const disableIfRunningOrNotExperiment = ( + const disableIfRunningOrWorkspace = ( label: string, type: MessageFromWebviewType, divider?: boolean - ) => disableIfRunning(label, type, isNotExperiment, divider) + ) => disableIfRunning(label, type, isWorkspace, divider) return [ experimentMenuOption( @@ -242,11 +242,11 @@ const getSingleSelectMenuOptions = ( MessageFromWebviewType.SHOW_EXPERIMENT_LOGS, !isRunningInQueue({ executor, status }) ), - disableIfRunningOrNotExperiment( + disableIfRunningOrWorkspace( 'Apply to Workspace', MessageFromWebviewType.APPLY_EXPERIMENT_TO_WORKSPACE ), - disableIfRunningOrNotExperiment( + disableIfRunningOrWorkspace( 'Create new Branch', MessageFromWebviewType.CREATE_BRANCH_FROM_EXPERIMENT ), diff --git a/webview/src/test/experimentsTable.tsx b/webview/src/test/experimentsTable.tsx index 7cac3c4394..458e85e854 100644 --- a/webview/src/test/experimentsTable.tsx +++ b/webview/src/test/experimentsTable.tsx @@ -60,9 +60,12 @@ export const renderTableWithSortingData = () => { return renderTable(sortingTableDataFixture) } -export const renderTableWithoutRunningExperiments = () => { +export const renderTableWithoutRunningExperiments = ( + hasCheckpoints?: boolean +) => { renderTable({ ...tableDataFixture, + hasCheckpoints: hasCheckpoints ?? tableDataFixture.hasCheckpoints, hasRunningWorkspaceExperiment: false, rows: tableDataFixture.rows.map(row => ({ ...row,