Skip to content

Commit

Permalink
Ensure experiment summary info (columns) are always availabe in the e…
Browse files Browse the repository at this point in the history
…xperiment table data
  • Loading branch information
mattseddon committed Aug 1, 2023
1 parent e6e0dac commit 5acf18a
Show file tree
Hide file tree
Showing 8 changed files with 231 additions and 159 deletions.
35 changes: 35 additions & 0 deletions extension/src/experiments/columns/collect/order.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import { Column, ColumnType } from '../../webview/contract'
import { EXPERIMENT_COLUMN_ID } from '../constants'

export const collectColumnOrder = async (
existingColumnOrder: string[],
terminalNodes: Column[]
): Promise<string[]> => {
const acc: { [columnType: string]: string[] } = {
[ColumnType.DEPS]: [],
[ColumnType.METRICS]: [],
[ColumnType.PARAMS]: [],
[ColumnType.TIMESTAMP]: []
}
for (const { type, path } of terminalNodes) {
if (existingColumnOrder.includes(path)) {
continue
}
acc[type].push(path)
}

// eslint-disable-next-line etc/no-assign-mutated-array
await Promise.all([acc.metrics.sort(), acc.params.sort(), acc.deps.sort()])

if (!existingColumnOrder.includes(EXPERIMENT_COLUMN_ID)) {
existingColumnOrder.unshift(EXPERIMENT_COLUMN_ID)
}

return [
...existingColumnOrder,
...acc.timestamp,
...acc.metrics,
...acc.params,
...acc.deps
]
}
81 changes: 69 additions & 12 deletions extension/src/experiments/columns/model.test.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import { join } from 'path'
import { Disposable, Disposer } from '@hediet/std/disposable'
import { ColumnsModel } from './model'
import { appendColumnToPath, buildMetricOrParamPath } from './paths'
import { timestampColumn } from './constants'
import { EXPERIMENT_COLUMN_ID, timestampColumn } from './constants'
import { buildMockMemento } from '../../test/util'
import { generateTestExpShowOutput } from '../../test/util/experiments'
import { Status } from '../../path/selection/model'
Expand Down Expand Up @@ -258,7 +259,7 @@ describe('ColumnsModel', () => {
expect(model.getColumnOrder()).toStrictEqual(persistedState)
})

it('should return the first three visible columns for both metrics and params from the persisted state', async () => {
it('should return the first three visible columns for both metrics and params from the persisted state (first)', async () => {
const persistedState = [
'id',
'Created',
Expand Down Expand Up @@ -287,7 +288,8 @@ describe('ColumnsModel', () => {
'params:params.yaml:process.threshold',
'params:params.yaml:process.test_arg',
'metrics:summary.json:loss',
'metrics:summary.json:accuracy'
'metrics:summary.json:accuracy',
'metrics:summary.json:val_accuracy'
])

model.toggleStatus('params:params.yaml:dvc_logs_dir')
Expand All @@ -297,10 +299,65 @@ describe('ColumnsModel', () => {
'params:params.yaml:process.test_arg',
'params:params.yaml:dropout',
'metrics:summary.json:loss',
'metrics:summary.json:accuracy'
'metrics:summary.json:accuracy',
'metrics:summary.json:val_accuracy'
])
})

it('should not add a param that is no longer present to the summary column order', async () => {
const persistedState = [
'id',
'Created',
'params:params.yaml:an-old-params'
]

const model = new ColumnsModel(
exampleDvcRoot,
buildMockMemento({
[PersistenceKey.METRICS_AND_PARAMS_COLUMN_ORDER + exampleDvcRoot]:
persistedState
}),
mockedColumnsOrderOrStatusChanged
)
model.toggleStatus('params:params.yaml:an-old-params')
await model.transformAndSet(outputFixture)

expect(model.getSummaryColumnOrder()).toStrictEqual([
join('params:nested', 'params.yaml:test'),
'params:params.yaml:code_names',
'params:params.yaml:dropout',
'metrics:summary.json:accuracy',
'metrics:summary.json:loss',
'metrics:summary.json:val_accuracy'
])
})

it('should add to the persisted state when there are columns that were not found', async () => {
const persistedState = ['params:params.yaml:dvc_logs_dir']

const model = new ColumnsModel(
exampleDvcRoot,
buildMockMemento({
[PersistenceKey.METRICS_AND_PARAMS_COLUMN_ORDER + exampleDvcRoot]:
persistedState
}),
mockedColumnsOrderOrStatusChanged
)
await model.transformAndSet(outputFixture)

expect(model.getSummaryColumnOrder()).toStrictEqual([
'params:params.yaml:dvc_logs_dir',
join('params:nested', 'params.yaml:test'),
'params:params.yaml:code_names',
'metrics:summary.json:accuracy',
'metrics:summary.json:loss',
'metrics:summary.json:val_accuracy'
])

const [id] = model.getColumnOrder()
expect(id).toStrictEqual(EXPERIMENT_COLUMN_ID)
})

it('should return the first three metric and param columns (none hidden) collected from data if state is empty', async () => {
const model = new ColumnsModel(
exampleDvcRoot,
Expand All @@ -310,23 +367,23 @@ describe('ColumnsModel', () => {
await model.transformAndSet(outputFixture)

expect(model.getSummaryColumnOrder()).toStrictEqual([
join('params:nested', 'params.yaml:test'),
'params:params.yaml:code_names',
'params:params.yaml:epochs',
'params:params.yaml:learning_rate',
'metrics:summary.json:loss',
'params:params.yaml:dropout',
'metrics:summary.json:accuracy',
'metrics:summary.json:val_loss'
'metrics:summary.json:loss',
'metrics:summary.json:val_accuracy'
])

model.toggleStatus('params:params.yaml:code_names')

expect(model.getSummaryColumnOrder()).toStrictEqual([
'params:params.yaml:epochs',
'params:params.yaml:learning_rate',
join('params:nested', 'params.yaml:test'),
'params:params.yaml:dropout',
'params:params.yaml:dvc_logs_dir',
'metrics:summary.json:loss',
'metrics:summary.json:accuracy',
'metrics:summary.json:val_loss'
'metrics:summary.json:loss',
'metrics:summary.json:val_accuracy'
])
})

Expand Down
51 changes: 16 additions & 35 deletions extension/src/experiments/columns/model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import {
collectRelativeMetricsFiles,
collectParamsFiles
} from './collect'
import { EXPERIMENT_COLUMN_ID, timestampColumn } from './constants'
import { collectColumnOrder } from './collect/order'
import { SummaryAcc, collectFromColumnOrder } from './util'
import { Column, ColumnType } from '../webview/contract'
import { ExpShowOutput } from '../../cli/dvc/contract'
Expand Down Expand Up @@ -152,38 +152,13 @@ export class ColumnsModel extends PathSelectionModel<Column> {
)
}

private findChildrenColumns(
parent: string,
columns: Column[],
childrenColumns: string[]
) {
const filteredColumns = columns.filter(
({ parentPath }) => parentPath === parent
private async setColumnOrderFromData(terminalNodes: Column[]) {
const extendedColumnOrder = await collectColumnOrder(
this.columnOrderState,
terminalNodes
)
for (const column of filteredColumns) {
if (column.hasChildren) {
this.findChildrenColumns(column.path, columns, childrenColumns)
} else {
childrenColumns.push(column.path)
}
}
}

private getColumnsFromType(type: string): string[] {
const childrenColumns: string[] = []
const dataWithType = this.data.filter(({ path }) => path.startsWith(type))
this.findChildrenColumns(type, dataWithType, childrenColumns)
return childrenColumns
}

private getColumnOrderFromData() {
return [
EXPERIMENT_COLUMN_ID,
timestampColumn.path,
...this.getColumnsFromType(ColumnType.METRICS),
...this.getColumnsFromType(ColumnType.PARAMS),
...this.getColumnsFromType(ColumnType.DEPS)
]
this.setColumnOrder(extendedColumnOrder)
}

private async transformAndSetColumns(data: ExpShowOutput) {
Expand All @@ -197,12 +172,18 @@ export class ColumnsModel extends PathSelectionModel<Column> {

this.data = columns

if (this.columnOrderState.length === 0) {
this.setColumnOrder(this.getColumnOrderFromData())
}

this.paramsFiles = paramsFiles
this.relativeMetricsFiles = relativeMetricsFiles

const selectedColumns = this.getTerminalNodes().filter(
({ selected }) => selected
)

for (const { path } of selectedColumns) {
if (!this.columnOrderState.includes(path)) {
return this.setColumnOrderFromData(selectedColumns)
}
}
}

private transformAndSetChanges(data: ExpShowOutput) {
Expand Down
12 changes: 12 additions & 0 deletions extension/src/path/selection/model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,22 @@ export abstract class PathSelectionModel<
}

protected setNewStatuses(data: { path: string }[]) {
const paths = new Set<string>()
for (const { path } of data) {
if (this.status[path] === undefined) {
this.status[path] = Status.SELECTED
}
paths.add(path)
}

this.removeMissingSelected(paths)
}

private removeMissingSelected(paths: Set<string>) {
for (const [path, status] of Object.entries(this.status)) {
if (!paths.has(path) && status === Status.SELECTED) {
delete this.status[path]
}
}
}

Expand Down
22 changes: 11 additions & 11 deletions extension/src/test/fixtures/expShow/base/columns.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,27 +11,27 @@ const nestedParamsFile = join('nested', 'params.yaml')
export const dataColumnOrder: string[] = [
'id',
'Created',
'metrics:summary.json:loss',
'metrics:summary.json:accuracy',
'metrics:summary.json:val_loss',
'metrics:summary.json:loss',
'metrics:summary.json:val_accuracy',
'metrics:summary.json:val_loss',
join('params:nested', 'params.yaml:test'),
'params:params.yaml:code_names',
'params:params.yaml:dropout',
'params:params.yaml:dvc_logs_dir',
'params:params.yaml:epochs',
'params:params.yaml:learning_rate',
'params:params.yaml:dvc_logs_dir',
'params:params.yaml:log_file',
'params:params.yaml:dropout',
'params:params.yaml:process.threshold',
'params:params.yaml:process.test_arg',
join('params:nested', 'params.yaml:test'),
'params:params.yaml:process.threshold',
join('deps:data', 'data.xml'),
join('deps:data', 'prepared'),
join('deps:data', 'features'),
join('deps:src', 'prepare.py'),
join('deps:src', 'featurization.py'),
join('deps:src', 'train.py'),
join('deps:data', 'prepared'),
'deps:model.pkl',
join('deps:src', 'evaluate.py'),
'deps:model.pkl'
join('deps:src', 'featurization.py'),
join('deps:src', 'prepare.py'),
join('deps:src', 'train.py')
]

const data: Column[] = [
Expand Down
Loading

0 comments on commit 5acf18a

Please sign in to comment.