Skip to content

Commit

Permalink
Enable running exp apply and exp branch against commits
Browse files Browse the repository at this point in the history
  • Loading branch information
mattseddon committed May 8, 2023
1 parent fca9961 commit 7e2f8cb
Show file tree
Hide file tree
Showing 10 changed files with 100 additions and 25 deletions.
4 changes: 2 additions & 2 deletions extension/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -1163,12 +1163,12 @@
{
"command": "dvc.views.experiments.applyExperiment",
"group": "inline@1",
"when": "view == dvc.views.experimentsTree && dvc.commands.available && viewItem == experiment && !dvc.experiment.running.workspace"
"when": "view == dvc.views.experimentsTree && dvc.commands.available && viewItem =~ /^(experiment|commit)$/ && !dvc.experiment.running.workspace"
},
{
"command": "dvc.views.experiments.branchExperiment",
"group": "inline@2",
"when": "view == dvc.views.experimentsTree && dvc.commands.available && viewItem == experiment && !dvc.experiment.running.workspace"
"when": "view == dvc.views.experimentsTree && dvc.commands.available && viewItem =~ /^(experiment|commit)$/ && !dvc.experiment.running.workspace"
},
{
"command": "dvc.views.experimentsTree.removeExperiment",
Expand Down
4 changes: 2 additions & 2 deletions extension/src/experiments/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -371,9 +371,9 @@ export class Experiments extends BaseRepository<TableData> {
return this.notifyChanged()
}

public pickExperiment() {
public pickCommitOrExperiment() {
return pickExperiment(
this.experiments.getExperiments(),
this.experiments.getCommitsAndExperiments(),
this.getFirstThreeColumnOrder()
)
}
Expand Down
32 changes: 31 additions & 1 deletion extension/src/experiments/model/collect.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,12 @@ import {
} from 'vscode'
import { ExperimentType } from '.'
import { extractColumns } from '../columns/extract'
import { Experiment, CommitData, RunningExperiment } from '../webview/contract'
import {
Experiment,
CommitData,
RunningExperiment,
isQueued
} from '../webview/contract'
import {
EXPERIMENT_WORKSPACE_ID,
ExperimentStatus,
Expand Down Expand Up @@ -372,3 +377,28 @@ export const collectExperimentType = (

return acc
}

const collectExperimentsAndCommit = (
acc: Experiment[],
commit: Experiment,
experiments: Experiment[] = []
): void => {
acc.push(commit)
for (const experiment of experiments) {
if (isQueued(experiment.status)) {
continue
}
acc.push(experiment)
}
}

export const collectOrderedCommitsAndExperiments = (
commits: Experiment[],
getExperimentsByCommit: (commit: Experiment) => Experiment[] | undefined
): Experiment[] => {
const acc: Experiment[] = []
for (const commit of commits) {
collectExperimentsAndCommit(acc, commit, getExperimentsByCommit(commit))
}
return acc
}
11 changes: 10 additions & 1 deletion extension/src/experiments/model/index.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import { Memento } from 'vscode'
import { SortDefinition, sortExperiments } from './sortBy'
import { FilterDefinition, filterExperiment, getFilterId } from './filterBy'
import { collectExperiments } from './collect'
import {
collectExperiments,
collectOrderedCommitsAndExperiments
} from './collect'
import {
collectColoredStatus,
collectFinishedRunningExperiments,
Expand Down Expand Up @@ -319,6 +322,12 @@ export class ExperimentsModel extends ModelWithPersistence {
})
}

public getCommitsAndExperiments() {
return collectOrderedCommitsAndExperiments(this.commits, commit =>
this.getExperimentsByCommit(commit)
)
}

public getExperimentsAndQueued() {
return flattenMapValues(this.experimentsByCommit).map(experiment =>
this.addDetails(experiment)
Expand Down
16 changes: 8 additions & 8 deletions extension/src/experiments/workspace.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ const mockedQuickPickOne = jest.mocked(quickPickOne)
const mockedQuickPickManyValues = jest.mocked(quickPickManyValues)
const mockedQuickPickOneOrInput = jest.mocked(quickPickOneOrInput)
const mockedGetValidInput = jest.mocked(getValidInput)
const mockedPickExperiment = jest.fn()
const mockedPickCommitOrExperiment = jest.fn()
const mockedGetInput = jest.mocked(getInput)
const mockedRun = jest.fn()
const mockedExpFunc = jest.fn()
Expand Down Expand Up @@ -91,7 +91,7 @@ describe('Experiments', () => {
{
'/my/dvc/root': {
getDvcRoot: () => mockedDvcRoot,
pickExperiment: mockedPickExperiment,
pickCommitOrExperiment: mockedPickCommitOrExperiment,
showWebview: mockedShowWebview
} as unknown as Experiments,
'/my/fun/dvc/root': {
Expand Down Expand Up @@ -138,12 +138,12 @@ describe('Experiments', () => {
it('should call the correct function with the correct parameters if a project and experiment are picked', async () => {
mockedQuickPickOne.mockResolvedValueOnce(mockedDvcRoot)
mockedListStages.mockResolvedValueOnce('train')
mockedPickExperiment.mockResolvedValueOnce('a123456')
mockedPickCommitOrExperiment.mockResolvedValueOnce('a123456')

await workspaceExperiments.getCwdAndExpNameThenRun(mockedCommandId)

expect(mockedQuickPickOne).toHaveBeenCalledTimes(1)
expect(mockedPickExperiment).toHaveBeenCalledTimes(1)
expect(mockedPickCommitOrExperiment).toHaveBeenCalledTimes(1)
expect(mockedExpFunc).toHaveBeenCalledTimes(1)
expect(mockedExpFunc).toHaveBeenCalledWith(mockedDvcRoot, 'a123456')
})
Expand Down Expand Up @@ -240,7 +240,7 @@ describe('Experiments', () => {
it('should call the correct function with the correct parameters if a project and experiment are picked and an input provided', async () => {
mockedQuickPickOne.mockResolvedValueOnce(mockedDvcRoot)
mockedListStages.mockResolvedValueOnce('train')
mockedPickExperiment.mockResolvedValueOnce('a123456')
mockedPickCommitOrExperiment.mockResolvedValueOnce('a123456')
mockedGetInput.mockResolvedValueOnce('abc123')

await workspaceExperiments.getCwdExpNameAndInputThenRun(
Expand All @@ -250,7 +250,7 @@ describe('Experiments', () => {
)

expect(mockedQuickPickOne).toHaveBeenCalledTimes(1)
expect(mockedPickExperiment).toHaveBeenCalledTimes(1)
expect(mockedPickCommitOrExperiment).toHaveBeenCalledTimes(1)
expect(mockedExpFunc).toHaveBeenCalledTimes(1)
expect(mockedExpFunc).toHaveBeenCalledWith(
mockedDvcRoot,
Expand All @@ -276,7 +276,7 @@ describe('Experiments', () => {
it('should not call the function if user input is not provided', async () => {
mockedQuickPickOne.mockResolvedValueOnce(mockedDvcRoot)
mockedListStages.mockResolvedValueOnce('train')
mockedPickExperiment.mockResolvedValueOnce({
mockedPickCommitOrExperiment.mockResolvedValueOnce({
id: 'b456789',
name: 'exp-456'
})
Expand All @@ -296,7 +296,7 @@ describe('Experiments', () => {
it('should check and ask for the creation of a pipeline stage before running the command', async () => {
mockedQuickPickOne.mockResolvedValueOnce(mockedDvcRoot)
mockedListStages.mockResolvedValueOnce('')
mockedPickExperiment.mockResolvedValueOnce({
mockedPickCommitOrExperiment.mockResolvedValueOnce({
id: 'a123456',
name: 'exp-123'
})
Expand Down
10 changes: 6 additions & 4 deletions extension/src/experiments/workspace.ts
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,9 @@ export class WorkspaceExperiments extends BaseWorkspaceWebviews<
}

public getCwdAndExpNameThenRun(commandId: CommandId) {
return this.pickExpThenRun(commandId, cwd => this.pickExperiment(cwd))
return this.pickExpThenRun(commandId, cwd =>
this.pickCommitOrExperiment(cwd)
)
}

public async getCwdAndQuickPickThenRun(
Expand All @@ -237,7 +239,7 @@ export class WorkspaceExperiments extends BaseWorkspaceWebviews<
return
}

const experimentId = await this.pickExperiment(cwd)
const experimentId = await this.pickCommitOrExperiment(cwd)

if (!experimentId) {
return
Expand Down Expand Up @@ -545,7 +547,7 @@ export class WorkspaceExperiments extends BaseWorkspaceWebviews<
return this.runCommand(commandId, cwd, experimentId)
}

private pickExperiment(cwd: string) {
return this.getRepository(cwd).pickExperiment()
private pickCommitOrExperiment(cwd: string) {
return this.getRepository(cwd).pickCommitOrExperiment()
}
}
11 changes: 10 additions & 1 deletion extension/src/test/suite/experiments/workspace.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -542,7 +542,7 @@ suite('Workspace Experiments Test Suite', () => {
})

describe('dvc.applyExperiment', () => {
it('should ask the user to pick an experiment and then apply that experiment to the workspace', async () => {
it('should ask the user to pick a commit or experiment and then apply it to the workspace', async () => {
const selectedExperiment = 'test-branch'

const { experiments } = buildExperiments(disposable)
Expand All @@ -562,8 +562,17 @@ suite('Workspace Experiments Test Suite', () => {
dvcDemoPath,
selectedExperiment
)

expect(mockShowQuickPick).to.be.calledWith(
[
{
description: undefined,
detail: `Created:${formatDate(
'2020-11-21T19:58:22'
)}, loss:2.0488560, accuracy:0.34848332`,
label: 'main',
value: 'main'
},
{
description: '[exp-e7a67]',
detail: `Created:${formatDate(
Expand Down
24 changes: 23 additions & 1 deletion webview/src/experiments/components/App.test.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -858,7 +858,7 @@ describe('App', () => {
expect(itemLabels).toStrictEqual(['Modify and Run', 'Modify and Queue'])
})

it('should enable the correct options for the main row with checkpoints', () => {
it('should enable the correct options for a commit with checkpoints', () => {
renderTableWithoutRunningExperiments()

const target = screen.getByText('main')
Expand All @@ -870,13 +870,35 @@ describe('App', () => {
.filter(item => !item.className.includes('disabled'))
.map(item => item.textContent)
expect(itemLabels).toStrictEqual([
'Apply to Workspace',
'Create new Branch',
'Modify and Run',
'Modify and Resume',
'Modify and Queue',
'Star'
])
})

it('should enable the correct options for a commit without checkpoints', () => {
renderTableWithoutRunningExperiments(false)

const target = screen.getByText('main')
fireEvent.contextMenu(target, { bubbles: true })

advanceTimersByTime(100)
const menuitems = screen.getAllByRole('menuitem')
const itemLabels = menuitems
.filter(item => !item.className.includes('disabled'))
.map(item => item.textContent)
expect(itemLabels).toStrictEqual([
'Apply to Workspace',
'Create new Branch',
'Modify and Run',
'Modify and Queue',
'Star'
])
})

it('should enable the correct options for an experiment that is not running and close on esc', () => {
renderTableWithoutRunningExperiments()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -229,11 +229,11 @@ const getSingleSelectMenuOptions = (
divider
)

const disableIfRunningOrNotExperiment = (
const disableIfRunningOrWorkspace = (
label: string,
type: MessageFromWebviewType,
divider?: boolean
) => disableIfRunning(label, type, isNotExperiment, divider)
) => disableIfRunning(label, type, isWorkspace, divider)

return [
experimentMenuOption(
Expand All @@ -242,11 +242,11 @@ const getSingleSelectMenuOptions = (
MessageFromWebviewType.SHOW_EXPERIMENT_LOGS,
!isRunningInQueue({ executor, status })
),
disableIfRunningOrNotExperiment(
disableIfRunningOrWorkspace(
'Apply to Workspace',
MessageFromWebviewType.APPLY_EXPERIMENT_TO_WORKSPACE
),
disableIfRunningOrNotExperiment(
disableIfRunningOrWorkspace(
'Create new Branch',
MessageFromWebviewType.CREATE_BRANCH_FROM_EXPERIMENT
),
Expand Down
5 changes: 4 additions & 1 deletion webview/src/test/experimentsTable.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,12 @@ export const renderTableWithSortingData = () => {
return renderTable(sortingTableDataFixture)
}

export const renderTableWithoutRunningExperiments = () => {
export const renderTableWithoutRunningExperiments = (
hasCheckpoints?: boolean
) => {
renderTable({
...tableDataFixture,
hasCheckpoints: hasCheckpoints ?? tableDataFixture.hasCheckpoints,
hasRunningWorkspaceExperiment: false,
rows: tableDataFixture.rows.map(row => ({
...row,
Expand Down

0 comments on commit 7e2f8cb

Please sign in to comment.