Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Accomodate params that are lists #1818

Merged
merged 5 commits into from
Jun 2, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion extension/src/cli/reader.ts
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,14 @@ export type StatusesOrAlwaysChanged = StageOrFileStatuses | 'always changed'

export type StatusOutput = Record<string, StatusesOrAlwaysChanged[]>

export type Value = string | number | boolean | null
export type Value =
| string
| number
| boolean
| null
| number[]
| string[]
| boolean[]

export interface ValueTreeOrError {
data?: ValueTree
Expand Down
1 change: 1 addition & 0 deletions extension/src/experiments/columns/collect.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,7 @@ describe('collectColumns', () => {
joinColumnPath(ColumnType.METRICS, 'summary.json', 'val_loss'),
joinColumnPath(ColumnType.METRICS, 'summary.json', 'val_accuracy'),
joinColumnPath(ColumnType.PARAMS, 'params.yaml'),
joinColumnPath(ColumnType.PARAMS, 'params.yaml', 'code_names'),
joinColumnPath(ColumnType.PARAMS, 'params.yaml', 'epochs'),
joinColumnPath(ColumnType.PARAMS, 'params.yaml', 'learning_rate'),
joinColumnPath(ColumnType.PARAMS, 'params.yaml', 'dvc_logs_dir'),
Expand Down
3 changes: 3 additions & 0 deletions extension/src/experiments/columns/collect.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ const getValueType = (value: Value) => {
if (value === null) {
return 'null'
}
if (Array.isArray(value)) {
return 'array'
}
return typeof value
}

Expand Down
8 changes: 8 additions & 0 deletions extension/src/experiments/columns/tree.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,14 @@ describe('ExperimentsColumnsTree', () => {
path: paramsPath
})
expect(grandChildren).toStrictEqual([
{
collapsibleState: 0,
description: undefined,
dvcRoot: mockedDvcRoot,
iconPath: mockedSelectedCheckbox,
label: 'code_names',
path: appendColumnToPath(paramsPath, 'code_names')
},
{
collapsibleState: 0,
description: undefined,
Expand Down
2 changes: 1 addition & 1 deletion extension/src/experiments/columns/walk.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ const walkValueTree = (
ancestors: string[] = []
) => {
for (const [key, value] of Object.entries(tree)) {
if (value && typeof value === 'object') {
if (value && !Array.isArray(value) && typeof value === 'object') {
Copy link
Member Author

@mattseddon mattseddon Jun 2, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

typeof [0,1] === 'object' which meant that we then split the array into different params or metrics and got the error shown in #1814 (everywhere, table was even broken).

walkValueTree(value, meta, onValue, [...ancestors, key])
} else {
onValue(key, value, meta, ancestors)
Expand Down
1 change: 1 addition & 0 deletions extension/src/experiments/model/queue/collect.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ describe('collectFlatExperimentParams', () => {
it('should flatten the params into an array', () => {
const params = collectFlatExperimentParams(rowsFixture[0].params)
expect(params).toStrictEqual([
{ path: appendColumnToPath('params.yaml', 'code_names'), value: [0, 1] },
{ path: appendColumnToPath('params.yaml', 'epochs'), value: 2 },
{
path: appendColumnToPath('params.yaml', 'learning_rate'),
Expand Down
6 changes: 3 additions & 3 deletions extension/src/experiments/model/queue/collect.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,18 @@ import { Columns } from '../../webview/contract'

export type Param = {
path: string
value: number | string | boolean
value: Value
}

const collectFromParamsFile = (
acc: { path: string; value: string | number | boolean }[],
acc: { path: string; value: Value }[],
key: string | undefined,
value: Value | ValueTree,
ancestors: string[] = []
) => {
const pathArray = [...ancestors, key].filter(Boolean) as string[]

if (typeof value === 'object') {
if (!Array.isArray(value) && typeof value === 'object') {
for (const [childKey, childValue] of Object.entries(value as ValueTree)) {
collectFromParamsFile(acc, childKey, childValue, pathArray)
}
Expand Down
15 changes: 13 additions & 2 deletions extension/src/experiments/model/queue/quickPick.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -48,27 +48,38 @@ describe('pickAndModifyParams', () => {
const unchanged = { path: 'params.yaml:learning_rate', value: 2e-12 }
const initialUserResponse = [
{ path: 'params.yaml:dropout', value: 0.15 },
{ path: 'params.yaml:process.threshold', value: 0.86 }
{ path: 'params.yaml:process.threshold', value: 0.86 },
{ path: 'params.yaml:code_names', value: [0, 1, 2] }
]
mockedQuickPickManyValues.mockResolvedValueOnce(initialUserResponse)
const firstInput = '0.16'
const secondInput = '0.87'
const thirdInput = '[0,1,3]'
mockedGetInput.mockResolvedValueOnce(firstInput)
mockedGetInput.mockResolvedValueOnce(secondInput)
mockedGetInput.mockResolvedValueOnce(thirdInput)

const paramsToQueue = await pickAndModifyParams([
unchanged,
...initialUserResponse
])

expect(mockedGetInput).toBeCalledTimes(3)
expect(mockedGetInput).toBeCalledWith(
'Enter a Value for params.yaml:code_names',
'[0,1,2]'
)

expect(paramsToQueue).toStrictEqual([
'-S',
`params.yaml:dropout=${firstInput}`,
'-S',
`params.yaml:process.threshold=${secondInput}`,
'-S',
`params.yaml:code_names=${thirdInput}`,
'-S',
[unchanged.path, unchanged.value].join('=')
])
expect(mockedGetInput).toBeCalledTimes(2)
expect(mockedGetInput).toBeCalledTimes(3)
})
})
15 changes: 11 additions & 4 deletions extension/src/experiments/model/queue/quickPick.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,15 @@ import { getInput } from '../../../vscode/inputBox'
import { Flag } from '../../../cli/constants'
import { definedAndNonEmpty } from '../../../util/array'
import { getEnterValueTitle, Title } from '../../../vscode/title'
import { Value } from '../../../cli/reader'

const standardizeValue = (value: Value): string =>
typeof value === 'object' ? JSON.stringify(value) : `${value}`

const pickParamsToModify = (params: Param[]): Thenable<Param[] | undefined> =>
quickPickManyValues<Param>(
params.map(param => ({
description: `${param.value}`,
description: standardizeValue(param.value),
label: param.path,
picked: false,
value: param
Expand All @@ -21,18 +25,21 @@ const pickNewParamValues = async (
): Promise<string[] | undefined> => {
const args: string[] = []
for (const { path, value } of paramsToModify) {
const input = await getInput(getEnterValueTitle(path), `${value}`)
const input = await getInput(
getEnterValueTitle(path),
standardizeValue(value)
)
if (input === undefined) {
return
}
args.push(Flag.SET_PARAM, [path, input.trim()].join('='))
args.push(Flag.SET_PARAM, [path, standardizeValue(input.trim())].join('='))
}
return args
}

const addUnchanged = (args: string[], unchanged: Param[]) => {
for (const { path, value } of unchanged) {
args.push(Flag.SET_PARAM, [path, value].join('='))
args.push(Flag.SET_PARAM, [path, standardizeValue(value)].join('='))
}

return args
Expand Down
10 changes: 10 additions & 0 deletions extension/src/test/fixtures/expShow/columns.ts
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,16 @@ const data: Column[] = [
parentPath: ColumnType.PARAMS,
path: joinColumnPath(ColumnType.PARAMS, 'params.yaml')
},
{
type: ColumnType.PARAMS,
hasChildren: false,
maxStringLength: 3,
name: 'code_names',
parentPath: joinColumnPath(ColumnType.PARAMS, 'params.yaml'),
path: joinColumnPath(ColumnType.PARAMS, 'params.yaml', 'code_names'),
pathArray: [ColumnType.PARAMS, 'params.yaml', 'code_names'],
types: ['array']
},
{
type: ColumnType.PARAMS,
hasChildren: false,
Expand Down
15 changes: 15 additions & 0 deletions extension/src/test/fixtures/expShow/output.ts
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ const data: ExperimentsOutput = {
params: {
'params.yaml': {
data: {
code_names: [0, 1],
epochs: 2,
learning_rate: 2.2e-7,
dvc_logs_dir: 'dvc_logs',
Expand Down Expand Up @@ -201,6 +202,7 @@ const data: ExperimentsOutput = {
params: {
'params.yaml': {
data: {
code_names: [0, 1],
epochs: 5,
learning_rate: 2.1e-7,
dvc_logs_dir: 'dvc_logs',
Expand Down Expand Up @@ -311,6 +313,7 @@ const data: ExperimentsOutput = {
params: {
'params.yaml': {
data: {
code_names: [0, 1],
epochs: 2,
learning_rate: 2e-12,
dvc_logs_dir: 'dvc_logs',
Expand Down Expand Up @@ -420,6 +423,7 @@ const data: ExperimentsOutput = {
params: {
'params.yaml': {
data: {
code_names: [0, 1],
epochs: 2,
learning_rate: 2e-12,
dvc_logs_dir: 'dvc_logs',
Expand Down Expand Up @@ -529,6 +533,7 @@ const data: ExperimentsOutput = {
params: {
'params.yaml': {
data: {
code_names: [0, 1],
epochs: 2,
learning_rate: 2e-12,
dvc_logs_dir: 'dvc_logs',
Expand Down Expand Up @@ -608,6 +613,7 @@ const data: ExperimentsOutput = {
params: {
'params.yaml': {
data: {
code_names: [0, 1],
epochs: 2,
learning_rate: 2.2e-7,
dvc_logs_dir: 'dvc_logs',
Expand Down Expand Up @@ -748,6 +754,7 @@ const data: ExperimentsOutput = {
params: {
'params.yaml': {
data: {
code_names: [0, 1],
epochs: 2,
learning_rate: 2.2e-7,
dvc_logs_dir: 'dvc_logs',
Expand Down Expand Up @@ -857,6 +864,7 @@ const data: ExperimentsOutput = {
params: {
'params.yaml': {
data: {
code_names: [0, 1],
epochs: 2,
learning_rate: 2.2e-7,
dvc_logs_dir: 'dvc_logs',
Expand Down Expand Up @@ -967,6 +975,7 @@ const data: ExperimentsOutput = {
params: {
'params.yaml': {
data: {
code_names: [0, 1],
epochs: 5,
learning_rate: 2.1e-7,
dvc_logs_dir: 'dvc_logs',
Expand Down Expand Up @@ -1046,6 +1055,7 @@ const data: ExperimentsOutput = {
params: {
'params.yaml': {
data: {
code_names: [0, 1],
epochs: 5,
learning_rate: 2.1e-7,
dvc_logs_dir: 'dvc_logs',
Expand Down Expand Up @@ -1155,6 +1165,7 @@ const data: ExperimentsOutput = {
params: {
'params.yaml': {
data: {
code_names: [0, 1],
epochs: 5,
learning_rate: 2.1e-7,
dvc_logs_dir: 'dvc_logs',
Expand Down Expand Up @@ -1294,6 +1305,7 @@ const data: ExperimentsOutput = {
params: {
'params.yaml': {
data: {
code_names: [0, 1],
epochs: 5,
learning_rate: 2.1e-7,
dvc_logs_dir: 'dvc_logs',
Expand Down Expand Up @@ -1403,6 +1415,7 @@ const data: ExperimentsOutput = {
params: {
'params.yaml': {
data: {
code_names: [0, 1],
epochs: 5,
learning_rate: 2.1e-7,
dvc_logs_dir: 'dvc_logs',
Expand Down Expand Up @@ -1512,6 +1525,7 @@ const data: ExperimentsOutput = {
params: {
'params.yaml': {
data: {
code_names: [0, 1],
epochs: 5,
learning_rate: 2.1e-7,
dvc_logs_dir: 'dvc_logs',
Expand Down Expand Up @@ -1608,6 +1622,7 @@ const data: ExperimentsOutput = {
params: {
'params.yaml': {
data: {
code_names: [0, 1],
epochs: 5,
learning_rate: 2.1e-7,
dvc_logs_dir: 'dvc_logs',
Expand Down
Loading