diff --git a/extension/src/cli/reader.ts b/extension/src/cli/reader.ts index 5b39963e62..f26046e280 100644 --- a/extension/src/cli/reader.ts +++ b/extension/src/cli/reader.ts @@ -49,7 +49,14 @@ export type StatusesOrAlwaysChanged = StageOrFileStatuses | 'always changed' export type StatusOutput = Record -export type Value = string | number | boolean | null +export type Value = + | string + | number + | boolean + | null + | number[] + | string[] + | boolean[] export interface ValueTreeOrError { data?: ValueTree diff --git a/extension/src/experiments/columns/collect.test.ts b/extension/src/experiments/columns/collect.test.ts index cc59c726ec..bf65a513a6 100644 --- a/extension/src/experiments/columns/collect.test.ts +++ b/extension/src/experiments/columns/collect.test.ts @@ -416,6 +416,7 @@ describe('collectColumns', () => { joinColumnPath(ColumnType.METRICS, 'summary.json', 'val_loss'), joinColumnPath(ColumnType.METRICS, 'summary.json', 'val_accuracy'), joinColumnPath(ColumnType.PARAMS, 'params.yaml'), + joinColumnPath(ColumnType.PARAMS, 'params.yaml', 'code_names'), joinColumnPath(ColumnType.PARAMS, 'params.yaml', 'epochs'), joinColumnPath(ColumnType.PARAMS, 'params.yaml', 'learning_rate'), joinColumnPath(ColumnType.PARAMS, 'params.yaml', 'dvc_logs_dir'), diff --git a/extension/src/experiments/columns/collect.ts b/extension/src/experiments/columns/collect.ts index 33dcbbe586..3addc8823e 100644 --- a/extension/src/experiments/columns/collect.ts +++ b/extension/src/experiments/columns/collect.ts @@ -17,6 +17,9 @@ const getValueType = (value: Value) => { if (value === null) { return 'null' } + if (Array.isArray(value)) { + return 'array' + } return typeof value } diff --git a/extension/src/experiments/columns/tree.test.ts b/extension/src/experiments/columns/tree.test.ts index 0056c4a9d9..30ed1561eb 100644 --- a/extension/src/experiments/columns/tree.test.ts +++ b/extension/src/experiments/columns/tree.test.ts @@ -217,6 +217,14 @@ describe('ExperimentsColumnsTree', () => { path: paramsPath }) expect(grandChildren).toStrictEqual([ + { + collapsibleState: 0, + description: undefined, + dvcRoot: mockedDvcRoot, + iconPath: mockedSelectedCheckbox, + label: 'code_names', + path: appendColumnToPath(paramsPath, 'code_names') + }, { collapsibleState: 0, description: undefined, diff --git a/extension/src/experiments/columns/walk.ts b/extension/src/experiments/columns/walk.ts index 4b481226e8..f2509c81cb 100644 --- a/extension/src/experiments/columns/walk.ts +++ b/extension/src/experiments/columns/walk.ts @@ -27,7 +27,7 @@ const walkValueTree = ( ancestors: string[] = [] ) => { for (const [key, value] of Object.entries(tree)) { - if (value && typeof value === 'object') { + if (value && !Array.isArray(value) && typeof value === 'object') { walkValueTree(value, meta, onValue, [...ancestors, key]) } else { onValue(key, value, meta, ancestors) diff --git a/extension/src/experiments/model/queue/collect.test.ts b/extension/src/experiments/model/queue/collect.test.ts index 20567bb069..ea9ed79c93 100644 --- a/extension/src/experiments/model/queue/collect.test.ts +++ b/extension/src/experiments/model/queue/collect.test.ts @@ -7,6 +7,7 @@ describe('collectFlatExperimentParams', () => { it('should flatten the params into an array', () => { const params = collectFlatExperimentParams(rowsFixture[0].params) expect(params).toStrictEqual([ + { path: appendColumnToPath('params.yaml', 'code_names'), value: [0, 1] }, { path: appendColumnToPath('params.yaml', 'epochs'), value: 2 }, { path: appendColumnToPath('params.yaml', 'learning_rate'), diff --git a/extension/src/experiments/model/queue/collect.ts b/extension/src/experiments/model/queue/collect.ts index d27b130e27..816f446b5d 100644 --- a/extension/src/experiments/model/queue/collect.ts +++ b/extension/src/experiments/model/queue/collect.ts @@ -4,18 +4,18 @@ import { Columns } from '../../webview/contract' export type Param = { path: string - value: number | string | boolean + value: Value } const collectFromParamsFile = ( - acc: { path: string; value: string | number | boolean }[], + acc: { path: string; value: Value }[], key: string | undefined, value: Value | ValueTree, ancestors: string[] = [] ) => { const pathArray = [...ancestors, key].filter(Boolean) as string[] - if (typeof value === 'object') { + if (!Array.isArray(value) && typeof value === 'object') { for (const [childKey, childValue] of Object.entries(value as ValueTree)) { collectFromParamsFile(acc, childKey, childValue, pathArray) } diff --git a/extension/src/experiments/model/queue/quickPick.test.ts b/extension/src/experiments/model/queue/quickPick.test.ts index 27c2a86c68..044e22c44d 100644 --- a/extension/src/experiments/model/queue/quickPick.test.ts +++ b/extension/src/experiments/model/queue/quickPick.test.ts @@ -48,27 +48,38 @@ describe('pickAndModifyParams', () => { const unchanged = { path: 'params.yaml:learning_rate', value: 2e-12 } const initialUserResponse = [ { path: 'params.yaml:dropout', value: 0.15 }, - { path: 'params.yaml:process.threshold', value: 0.86 } + { path: 'params.yaml:process.threshold', value: 0.86 }, + { path: 'params.yaml:code_names', value: [0, 1, 2] } ] mockedQuickPickManyValues.mockResolvedValueOnce(initialUserResponse) const firstInput = '0.16' const secondInput = '0.87' + const thirdInput = '[0,1,3]' mockedGetInput.mockResolvedValueOnce(firstInput) mockedGetInput.mockResolvedValueOnce(secondInput) + mockedGetInput.mockResolvedValueOnce(thirdInput) const paramsToQueue = await pickAndModifyParams([ unchanged, ...initialUserResponse ]) + expect(mockedGetInput).toBeCalledTimes(3) + expect(mockedGetInput).toBeCalledWith( + 'Enter a Value for params.yaml:code_names', + '[0,1,2]' + ) + expect(paramsToQueue).toStrictEqual([ '-S', `params.yaml:dropout=${firstInput}`, '-S', `params.yaml:process.threshold=${secondInput}`, '-S', + `params.yaml:code_names=${thirdInput}`, + '-S', [unchanged.path, unchanged.value].join('=') ]) - expect(mockedGetInput).toBeCalledTimes(2) + expect(mockedGetInput).toBeCalledTimes(3) }) }) diff --git a/extension/src/experiments/model/queue/quickPick.ts b/extension/src/experiments/model/queue/quickPick.ts index 708107912c..ebef9c2851 100644 --- a/extension/src/experiments/model/queue/quickPick.ts +++ b/extension/src/experiments/model/queue/quickPick.ts @@ -4,11 +4,15 @@ import { getInput } from '../../../vscode/inputBox' import { Flag } from '../../../cli/constants' import { definedAndNonEmpty } from '../../../util/array' import { getEnterValueTitle, Title } from '../../../vscode/title' +import { Value } from '../../../cli/reader' + +const standardizeValue = (value: Value): string => + typeof value === 'object' ? JSON.stringify(value) : `${value}` const pickParamsToModify = (params: Param[]): Thenable => quickPickManyValues( params.map(param => ({ - description: `${param.value}`, + description: standardizeValue(param.value), label: param.path, picked: false, value: param @@ -21,18 +25,21 @@ const pickNewParamValues = async ( ): Promise => { const args: string[] = [] for (const { path, value } of paramsToModify) { - const input = await getInput(getEnterValueTitle(path), `${value}`) + const input = await getInput( + getEnterValueTitle(path), + standardizeValue(value) + ) if (input === undefined) { return } - args.push(Flag.SET_PARAM, [path, input.trim()].join('=')) + args.push(Flag.SET_PARAM, [path, standardizeValue(input.trim())].join('=')) } return args } const addUnchanged = (args: string[], unchanged: Param[]) => { for (const { path, value } of unchanged) { - args.push(Flag.SET_PARAM, [path, value].join('=')) + args.push(Flag.SET_PARAM, [path, standardizeValue(value)].join('=')) } return args diff --git a/extension/src/test/fixtures/expShow/columns.ts b/extension/src/test/fixtures/expShow/columns.ts index 372b2bdb5f..d37c0d700e 100644 --- a/extension/src/test/fixtures/expShow/columns.ts +++ b/extension/src/test/fixtures/expShow/columns.ts @@ -67,6 +67,16 @@ const data: Column[] = [ parentPath: ColumnType.PARAMS, path: joinColumnPath(ColumnType.PARAMS, 'params.yaml') }, + { + type: ColumnType.PARAMS, + hasChildren: false, + maxStringLength: 3, + name: 'code_names', + parentPath: joinColumnPath(ColumnType.PARAMS, 'params.yaml'), + path: joinColumnPath(ColumnType.PARAMS, 'params.yaml', 'code_names'), + pathArray: [ColumnType.PARAMS, 'params.yaml', 'code_names'], + types: ['array'] + }, { type: ColumnType.PARAMS, hasChildren: false, diff --git a/extension/src/test/fixtures/expShow/output.ts b/extension/src/test/fixtures/expShow/output.ts index 8ee0232aec..f831b0a4ba 100644 --- a/extension/src/test/fixtures/expShow/output.ts +++ b/extension/src/test/fixtures/expShow/output.ts @@ -91,6 +91,7 @@ const data: ExperimentsOutput = { params: { 'params.yaml': { data: { + code_names: [0, 1], epochs: 2, learning_rate: 2.2e-7, dvc_logs_dir: 'dvc_logs', @@ -201,6 +202,7 @@ const data: ExperimentsOutput = { params: { 'params.yaml': { data: { + code_names: [0, 1], epochs: 5, learning_rate: 2.1e-7, dvc_logs_dir: 'dvc_logs', @@ -311,6 +313,7 @@ const data: ExperimentsOutput = { params: { 'params.yaml': { data: { + code_names: [0, 1], epochs: 2, learning_rate: 2e-12, dvc_logs_dir: 'dvc_logs', @@ -420,6 +423,7 @@ const data: ExperimentsOutput = { params: { 'params.yaml': { data: { + code_names: [0, 1], epochs: 2, learning_rate: 2e-12, dvc_logs_dir: 'dvc_logs', @@ -529,6 +533,7 @@ const data: ExperimentsOutput = { params: { 'params.yaml': { data: { + code_names: [0, 1], epochs: 2, learning_rate: 2e-12, dvc_logs_dir: 'dvc_logs', @@ -608,6 +613,7 @@ const data: ExperimentsOutput = { params: { 'params.yaml': { data: { + code_names: [0, 1], epochs: 2, learning_rate: 2.2e-7, dvc_logs_dir: 'dvc_logs', @@ -748,6 +754,7 @@ const data: ExperimentsOutput = { params: { 'params.yaml': { data: { + code_names: [0, 1], epochs: 2, learning_rate: 2.2e-7, dvc_logs_dir: 'dvc_logs', @@ -857,6 +864,7 @@ const data: ExperimentsOutput = { params: { 'params.yaml': { data: { + code_names: [0, 1], epochs: 2, learning_rate: 2.2e-7, dvc_logs_dir: 'dvc_logs', @@ -967,6 +975,7 @@ const data: ExperimentsOutput = { params: { 'params.yaml': { data: { + code_names: [0, 1], epochs: 5, learning_rate: 2.1e-7, dvc_logs_dir: 'dvc_logs', @@ -1046,6 +1055,7 @@ const data: ExperimentsOutput = { params: { 'params.yaml': { data: { + code_names: [0, 1], epochs: 5, learning_rate: 2.1e-7, dvc_logs_dir: 'dvc_logs', @@ -1155,6 +1165,7 @@ const data: ExperimentsOutput = { params: { 'params.yaml': { data: { + code_names: [0, 1], epochs: 5, learning_rate: 2.1e-7, dvc_logs_dir: 'dvc_logs', @@ -1294,6 +1305,7 @@ const data: ExperimentsOutput = { params: { 'params.yaml': { data: { + code_names: [0, 1], epochs: 5, learning_rate: 2.1e-7, dvc_logs_dir: 'dvc_logs', @@ -1403,6 +1415,7 @@ const data: ExperimentsOutput = { params: { 'params.yaml': { data: { + code_names: [0, 1], epochs: 5, learning_rate: 2.1e-7, dvc_logs_dir: 'dvc_logs', @@ -1512,6 +1525,7 @@ const data: ExperimentsOutput = { params: { 'params.yaml': { data: { + code_names: [0, 1], epochs: 5, learning_rate: 2.1e-7, dvc_logs_dir: 'dvc_logs', @@ -1608,6 +1622,7 @@ const data: ExperimentsOutput = { params: { 'params.yaml': { data: { + code_names: [0, 1], epochs: 5, learning_rate: 2.1e-7, dvc_logs_dir: 'dvc_logs', diff --git a/extension/src/test/fixtures/expShow/rows.ts b/extension/src/test/fixtures/expShow/rows.ts index dfdf2f8de9..ed971eebb7 100644 --- a/extension/src/test/fixtures/expShow/rows.ts +++ b/extension/src/test/fixtures/expShow/rows.ts @@ -92,6 +92,7 @@ const data: Row[] = [ }, params: { 'params.yaml': { + code_names: [0, 1], epochs: 2, learning_rate: 2.2e-7, dvc_logs_dir: 'dvc_logs', @@ -196,6 +197,7 @@ const data: Row[] = [ }, params: { 'params.yaml': { + code_names: [0, 1], epochs: 5, learning_rate: 2.1e-7, dvc_logs_dir: 'dvc_logs', @@ -304,6 +306,7 @@ const data: Row[] = [ }, params: { 'params.yaml': { + code_names: [0, 1], epochs: 2, learning_rate: 2e-12, dvc_logs_dir: 'dvc_logs', @@ -410,6 +413,7 @@ const data: Row[] = [ }, params: { 'params.yaml': { + code_names: [0, 1], epochs: 2, learning_rate: 2e-12, dvc_logs_dir: 'dvc_logs', @@ -517,6 +521,7 @@ const data: Row[] = [ }, params: { 'params.yaml': { + code_names: [0, 1], epochs: 2, learning_rate: 2e-12, dvc_logs_dir: 'dvc_logs', @@ -629,6 +634,7 @@ const data: Row[] = [ }, params: { 'params.yaml': { + code_names: [0, 1], epochs: 2, learning_rate: 2.2e-7, dvc_logs_dir: 'dvc_logs', @@ -735,6 +741,7 @@ const data: Row[] = [ }, params: { 'params.yaml': { + code_names: [0, 1], epochs: 2, learning_rate: 2.2e-7, dvc_logs_dir: 'dvc_logs', @@ -842,6 +849,7 @@ const data: Row[] = [ }, params: { 'params.yaml': { + code_names: [0, 1], epochs: 2, learning_rate: 2.2e-7, dvc_logs_dir: 'dvc_logs', @@ -954,6 +962,7 @@ const data: Row[] = [ }, params: { 'params.yaml': { + code_names: [0, 1], epochs: 5, learning_rate: 2.1e-7, dvc_logs_dir: 'dvc_logs', @@ -1060,6 +1069,7 @@ const data: Row[] = [ }, params: { 'params.yaml': { + code_names: [0, 1], epochs: 5, learning_rate: 2.1e-7, dvc_logs_dir: 'dvc_logs', @@ -1167,6 +1177,7 @@ const data: Row[] = [ }, params: { 'params.yaml': { + code_names: [0, 1], epochs: 5, learning_rate: 2.1e-7, dvc_logs_dir: 'dvc_logs', @@ -1274,6 +1285,7 @@ const data: Row[] = [ }, params: { 'params.yaml': { + code_names: [0, 1], epochs: 5, learning_rate: 2.1e-7, dvc_logs_dir: 'dvc_logs', @@ -1381,6 +1393,7 @@ const data: Row[] = [ }, params: { 'params.yaml': { + code_names: [0, 1], epochs: 5, learning_rate: 2.1e-7, dvc_logs_dir: 'dvc_logs', @@ -1488,6 +1501,7 @@ const data: Row[] = [ }, params: { 'params.yaml': { + code_names: [0, 1], epochs: 5, learning_rate: 2.1e-7, dvc_logs_dir: 'dvc_logs', @@ -1585,6 +1599,7 @@ const data: Row[] = [ }, params: { 'params.yaml': { + code_names: [0, 1], epochs: 5, learning_rate: 2.1e-7, dvc_logs_dir: 'dvc_logs', diff --git a/extension/src/test/suite/extension.test.ts b/extension/src/test/suite/extension.test.ts index 3e0e564bbc..ffdcaec8a5 100644 --- a/extension/src/test/suite/extension.test.ts +++ b/extension/src/test/suite/extension.test.ts @@ -328,7 +328,7 @@ suite('Extension Test Suite', () => { msPythonInstalled: true, msPythonUsed: false, noCheckpoints: 0, - params: 8, + params: 9, pythonPathUsed: false, templates: 3, tracked: 15, diff --git a/webview/src/experiments/util/buildDynamicColumns.test.ts b/webview/src/experiments/util/buildDynamicColumns.test.ts index 9c13f8441c..ae43420639 100644 --- a/webview/src/experiments/util/buildDynamicColumns.test.ts +++ b/webview/src/experiments/util/buildDynamicColumns.test.ts @@ -54,6 +54,22 @@ describe('buildDynamicColumns', () => { }, { columns: [ + { + columns: [ + { + id: joinColumnPath( + ColumnType.PARAMS, + 'params.yaml', + 'code_names' + ) + } + ], + id: joinColumnPath( + ColumnType.PARAMS, + 'params.yaml', + 'code_names_previous_placeholder' + ) + }, { columns: [ { diff --git a/webview/src/experiments/util/buildDynamicColumns.tsx b/webview/src/experiments/util/buildDynamicColumns.tsx index 3de1894858..39518c3568 100644 --- a/webview/src/experiments/util/buildDynamicColumns.tsx +++ b/webview/src/experiments/util/buildDynamicColumns.tsx @@ -83,8 +83,13 @@ const Header: React.FC<{ column: TableColumn }> = ({ } const buildAccessor: (valuePath: string[]) => Accessor = - pathArray => originalRow => - get(originalRow, pathArray) + pathArray => originalRow => { + const value = get(originalRow, pathArray) + if (!Array.isArray(value)) { + return value + } + return `[${value.join(', ')}]` + } const buildDynamicColumns = ( properties: Column[],