Skip to content

Commit

Permalink
Add default value for branch name when creating from experiment (#4037)
Browse files Browse the repository at this point in the history
  • Loading branch information
mattseddon authored Jun 5, 2023
1 parent c13f9a6 commit 9027045
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 51 deletions.
7 changes: 2 additions & 5 deletions extension/src/experiments/commands/register.ts
Original file line number Diff line number Diff line change
Expand Up @@ -129,11 +129,7 @@ const registerExperimentInputCommands = (
): void => {
internalCommands.registerExternalCliCommand(
RegisteredCliCommands.EXPERIMENT_BRANCH,
() =>
experiments.getCwdExpNameAndInputThenRun(
getBranchExperimentCommand(experiments),
Title.ENTER_BRANCH_NAME
)
() => experiments.createExperimentBranch()
)

internalCommands.registerExternalCliCommand(
Expand All @@ -142,6 +138,7 @@ const registerExperimentInputCommands = (
experiments.getInputAndRun(
getBranchExperimentCommand(experiments),
Title.ENTER_BRANCH_NAME,
`${id}-branch`,
dvcRoot,
id
)
Expand Down
65 changes: 26 additions & 39 deletions extension/src/experiments/workspace.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ import { buildMockMemento } from '../test/util'
import { buildMockedEventEmitter } from '../test/util/jest'
import { OutputChannel } from '../vscode/outputChannel'
import { Title } from '../vscode/title'
import { Args } from '../cli/dvc/constants'
import {
findOrCreateDvcYamlFile,
getFileExtension,
Expand Down Expand Up @@ -44,6 +43,7 @@ const mockedHasDvcYamlFile = jest.mocked(hasDvcYamlFile)
const mockedGetBranches = jest.fn()
const mockedGetCurrentBranch = jest.fn()
const mockedPickFile = jest.mocked(pickFile)
const mockedExpBranch = jest.fn()

jest.mock('vscode')
jest.mock('@hediet/std/disposable')
Expand Down Expand Up @@ -91,6 +91,11 @@ describe('Experiments', () => {
() => mockedGetCurrentBranch()
)

mockedInternalCommands.registerCommand(
AvailableCommands.EXP_BRANCH,
mockedExpBranch
)

const workspaceExperiments = new WorkspaceExperiments(
mockedInternalCommands,
buildMockMemento(),
Expand Down Expand Up @@ -242,79 +247,61 @@ describe('Experiments', () => {
})
})

describe('getCwdExpNameAndInputThenRun', () => {
it('should call the correct function with the correct parameters if a project and experiment are picked and an input provided', async () => {
describe('createExperimentBranch', () => {
it('should create a branch with the correct name if a project and experiment are picked and an input provided', async () => {
mockedQuickPickOne.mockResolvedValueOnce(mockedDvcRoot)
mockedListStages.mockResolvedValueOnce('train')
mockedPickCommitOrExperiment.mockResolvedValueOnce('a123456')
mockedGetInput.mockResolvedValueOnce('abc123')

await workspaceExperiments.getCwdExpNameAndInputThenRun(
(cwd: string, ...args: Args) =>
workspaceExperiments.runCommand(mockedCommandId, cwd, ...args),
'enter your password please' as Title
)
await workspaceExperiments.createExperimentBranch()

expect(mockedQuickPickOne).toHaveBeenCalledTimes(1)
expect(mockedGetInput).toHaveBeenCalledWith(
Title.ENTER_BRANCH_NAME,
'a123456-branch'
)
expect(mockedPickCommitOrExperiment).toHaveBeenCalledTimes(1)
expect(mockedExpFunc).toHaveBeenCalledTimes(1)
expect(mockedExpFunc).toHaveBeenCalledWith(
expect(mockedExpBranch).toHaveBeenCalledTimes(1)
expect(mockedExpBranch).toHaveBeenCalledWith(
mockedDvcRoot,
'a123456',
'abc123'
)
})

it('should not call the function or ask for input if a project is not picked', async () => {
it('should not ask for a branch name if a project is not picked', async () => {
mockedQuickPickOne.mockResolvedValueOnce(undefined)

await workspaceExperiments.getCwdExpNameAndInputThenRun(
(cwd: string, ...args: Args) =>
workspaceExperiments.runCommand(mockedCommandId, cwd, ...args),
'please name the branch' as Title
)
await workspaceExperiments.createExperimentBranch()

expect(mockedQuickPickOne).toHaveBeenCalledTimes(1)
expect(mockedGetInput).not.toHaveBeenCalled()
expect(mockedExpFunc).not.toHaveBeenCalled()
expect(mockedExpBranch).not.toHaveBeenCalled()
})

it('should not call the function if user input is not provided', async () => {
it('should not create a branch if the user does not provide a name', async () => {
mockedQuickPickOne.mockResolvedValueOnce(mockedDvcRoot)
mockedListStages.mockResolvedValueOnce('train')
mockedPickCommitOrExperiment.mockResolvedValueOnce({
id: 'b456789',
name: 'exp-456'
})
mockedPickCommitOrExperiment.mockResolvedValueOnce('exp-456')
mockedGetInput.mockResolvedValueOnce(undefined)

await workspaceExperiments.getCwdExpNameAndInputThenRun(
(cwd: string, ...args: Args) =>
workspaceExperiments.runCommand(mockedCommandId, cwd, ...args),
'please enter your bank account number and sort code' as Title
)
await workspaceExperiments.createExperimentBranch()

expect(mockedQuickPickOne).toHaveBeenCalledTimes(1)
expect(mockedGetInput).toHaveBeenCalledTimes(1)
expect(mockedExpFunc).not.toHaveBeenCalled()
expect(mockedExpBranch).not.toHaveBeenCalled()
})

it('should check and ask for the creation of a pipeline stage before running the command', async () => {
it('should check and ask for the creation of a pipeline stage before doing anything else', async () => {
mockedQuickPickOne.mockResolvedValueOnce(mockedDvcRoot)
mockedListStages.mockResolvedValueOnce('')
mockedPickCommitOrExperiment.mockResolvedValueOnce({
id: 'a123456',
name: 'exp-123'
})
mockedPickCommitOrExperiment.mockResolvedValueOnce('exp-123')
mockedGetInput.mockResolvedValueOnce('abc123')

await workspaceExperiments.getCwdExpNameAndInputThenRun(
(cwd: string, ...args: Args) =>
workspaceExperiments.runCommand(mockedCommandId, cwd, ...args),
'enter your password please' as Title
)
await workspaceExperiments.createExperimentBranch()

expect(mockedExpFunc).not.toHaveBeenCalled()
expect(mockedExpBranch).not.toHaveBeenCalled()
})
})

Expand Down
21 changes: 14 additions & 7 deletions extension/src/experiments/workspace.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import { EventEmitter, Memento } from 'vscode'
import isEmpty from 'lodash.isempty'
import { Experiments, ModifiedExperimentAndRunCommandId } from '.'
import { getPushExperimentCommand } from './commands'
import {
getBranchExperimentCommand,
getPushExperimentCommand
} from './commands'
import { TableData } from './webview/contract'
import { Args } from '../cli/dvc/constants'
import {
Expand Down Expand Up @@ -248,10 +251,7 @@ export class WorkspaceExperiments extends BaseWorkspaceWebviews<
}
}

public async getCwdExpNameAndInputThenRun(
runCommand: (cwd: string, ...args: Args) => Promise<void>,
title: Title
) {
public async createExperimentBranch() {
const cwd = await this.shouldRun()
if (!cwd) {
return
Expand All @@ -262,15 +262,22 @@ export class WorkspaceExperiments extends BaseWorkspaceWebviews<
if (!experimentId) {
return
}
return this.getInputAndRun(runCommand, title, cwd, experimentId)
return this.getInputAndRun(
getBranchExperimentCommand(this),
Title.ENTER_BRANCH_NAME,
`${experimentId}-branch`,
cwd,
experimentId
)
}

public async getInputAndRun(
runCommand: (...args: Args) => Promise<void> | void,
title: Title,
defaultValue: string,
...args: Args
) {
const input = await getInput(title)
const input = await getInput(title, defaultValue)
if (!input) {
return
}
Expand Down

0 comments on commit 9027045

Please sign in to comment.