Skip to content

Commit d600a23

Browse files
authored
Merge branch 'master' into fix-stub
2 parents 7d47629 + b6b2343 commit d600a23

File tree

18 files changed

+420
-46
lines changed

18 files changed

+420
-46
lines changed

extension/package.json

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1063,6 +1063,7 @@
10631063
"chokidar": "^3.5.2",
10641064
"execa": "^5.1.1",
10651065
"fs-extra": "^10.0.0",
1066+
"js-yaml": "^4.1.0",
10661067
"lodash.clonedeep": "^4.5.0",
10671068
"lodash.get": "^4.4.2",
10681069
"lodash.isempty": "^4.4.0",
@@ -1078,6 +1079,7 @@
10781079
"@types/copy-webpack-plugin": "^8.0.0",
10791080
"@types/fs-extra": "^9.0.13",
10801081
"@types/jest": "^27.4.0",
1082+
"@types/js-yaml": "^4.0.5",
10811083
"@types/lodash.clonedeep": "^4.5.6",
10821084
"@types/lodash.get": "^4.4.6",
10831085
"@types/lodash.isempty": "^4.4.6",
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
import { collectHasCheckpoints } from './collect'
2+
import { PartialDvcYaml } from '../../fileSystem'
3+
4+
describe('collectHasCheckpoints', () => {
  it('should correctly identify that the demo project has checkpoints', () => {
    // Single train stage whose model output is flagged as a checkpoint.
    const demoYaml = {
      stages: {
        train: {
          cmd: 'python train.py',
          deps: ['data/MNIST', 'train.py'],
          live: { logs: { html: true, summary: true } },
          outs: [{ 'model.pt': { checkpoint: true } }],
          params: ['seed', 'lr', 'weight_decay'],
          plots: [
            'plots',
            {
              'predictions.json': {
                cache: false,
                template: 'confusion',
                x: 'actual',
                y: 'predicted'
              }
            }
          ]
        }
      }
    } as PartialDvcYaml

    expect(collectHasCheckpoints(demoYaml)).toBe(true)
  })

  it('should correctly classify a dvc.yaml without checkpoint', () => {
    // Two stages; the only object-style output carries `cache`, not `checkpoint`.
    const yamlWithoutCheckpoint = {
      stages: {
        extract: {
          cmd: 'tar -xzf data/images.tar.gz --directory data',
          deps: ['data/images.tar.gz'],
          outs: [{ 'data/images/': { cache: false } }]
        },
        train: {
          cmd: 'python3 src/train.py',
          deps: ['data/images/', 'src/train.py'],
          live: { logs: { html: true, summary: true } },
          metrics: [{ 'metrics.json': { cache: false } }],
          outs: ['models/model.h5'],
          params: ['model.conv_units', 'train.epochs'],
          plots: [{ 'logs.csv': { cache: false } }]
        }
      }
    } as PartialDvcYaml

    expect(collectHasCheckpoints(yamlWithoutCheckpoint)).toBe(false)
  })

  it('should correctly classify a more complex dvc.yaml without checkpoint', () => {
    // Four-stage pipeline where every output is a plain path string.
    const complexYamlWithoutCheckpoint = {
      stages: {
        evaluate: {
          cmd: 'python src/evaluate.py model.pkl data/features scores.json prc.json roc.json',
          deps: ['data/features', 'model.pkl', 'src/evaluate.py'],
          metrics: [{ 'scores.json': { cache: false } }],
          plots: [
            { 'prc.json': { cache: false, x: 'recall', y: 'precision' } },
            { 'roc.json': { cache: false, x: 'fpr', y: 'tpr' } }
          ]
        },
        featurize: {
          cmd: 'python src/featurization.py data/prepared data/features',
          deps: ['data/prepared', 'src/featurization.py'],
          outs: ['data/features'],
          params: ['featurize.max_features', 'featurize.ngrams']
        },
        prepare: {
          cmd: 'python src/prepare.py data/data.xml',
          deps: ['data/data.xml', 'src/prepare.py'],
          outs: ['data/prepared'],
          params: ['prepare.seed', 'prepare.split']
        },
        train: {
          cmd: 'python src/train.py data/features model.pkl',
          deps: ['data/features', 'src/train.py'],
          outs: ['model.pkl'],
          params: ['train.min_split', 'train.n_est', 'train.seed']
        }
      }
    } as PartialDvcYaml

    expect(collectHasCheckpoints(complexYamlWithoutCheckpoint)).toBe(false)
  })
})
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
import { PartialDvcYaml } from '../../fileSystem'
2+
3+
/**
 * Reports whether a parsed dvc.yaml declares at least one checkpoint output.
 *
 * Every stage is inspected, not only one named `train`. A missing `stages`
 * mapping, or a stage without `outs`, is treated as "no checkpoints" instead
 * of throwing.
 *
 * An entry of `outs` is either a plain path string (never a checkpoint) or an
 * object keyed by path whose value may carry a `checkpoint` flag.
 *
 * @param yaml - parsed contents of a dvc.yaml file
 * @returns `true` if any output of any stage has a truthy `checkpoint` flag
 */
export const collectHasCheckpoints = (yaml: PartialDvcYaml): boolean =>
  Object.values(yaml?.stages ?? {}).some(stage =>
    (stage?.outs ?? []).some(
      out =>
        // `typeof null` is 'object', so null entries are excluded explicitly.
        typeof out === 'object' &&
        out !== null &&
        Object.values(out).some(file => !!file?.checkpoint)
    )
  )
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
import { join } from 'path'
2+
import { CheckpointsModel } from './model'
3+
import { dvcDemoPath } from '../../test/util'
4+
5+
describe('CheckpointsModel', () => {
  it('should keep a record of yaml files that have checkpoints', () => {
    const rootPath = join(dvcDemoPath, 'dvc.yaml')
    const extraPath = join(dvcDemoPath, 'extra', 'dvc.yaml')

    // Updates in which the train stage output is flagged as a checkpoint.
    const rootWithCheckpoints = {
      path: rootPath,
      yaml: {
        stages: {
          train: {
            outs: [{ 'model.pt': { checkpoint: true } }]
          }
        }
      }
    }
    const extraWithCheckpoints = {
      path: extraPath,
      yaml: {
        stages: {
          train: {
            outs: [{ 'extra-model.pt': { checkpoint: true } }]
          }
        }
      }
    }

    // Updates for the same two paths with plain string outputs only.
    const rootWithoutCheckpoints = {
      path: rootPath,
      yaml: {
        stages: { train: { outs: ['model.pt'] } }
      }
    }
    const extraWithoutCheckpoints = {
      path: extraPath,
      yaml: {
        stages: { train: { outs: ['extra-model.pt'] } }
      }
    }

    const model = new CheckpointsModel()

    // Nothing has been recorded yet.
    expect(model.hasCheckpoints()).toBe(false)

    model.transformAndSet(rootWithCheckpoints)
    expect(model.hasCheckpoints()).toBe(true)

    model.transformAndSet(extraWithCheckpoints)
    expect(model.hasCheckpoints()).toBe(true)

    // The root file loses its checkpoints but the extra file still has them.
    model.transformAndSet(rootWithoutCheckpoints)
    expect(model.hasCheckpoints()).toBe(true)

    // Once the last file loses its checkpoints the model reports none.
    model.transformAndSet(extraWithoutCheckpoints)
    expect(model.hasCheckpoints()).toBe(false)
  })
})
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
import { Disposable } from '@hediet/std/disposable'
2+
import { collectHasCheckpoints } from './collect'
3+
import { PartialDvcYaml } from '../../fileSystem'
4+
import { definedAndNonEmpty, uniqueValues } from '../../util/array'
5+
6+
/**
 * Keeps track of the dvc.yaml paths for which `collectHasCheckpoints`
 * last returned `true`.
 */
export class CheckpointsModel {
  public dispose = Disposable.fn()

  // Paths of the yaml files currently recorded as having checkpoints.
  private yamlWithCheckpoints: string[] = []

  /** Result of `definedAndNonEmpty` applied to the recorded paths. */
  public hasCheckpoints() {
    return definedAndNonEmpty(this.yamlWithCheckpoints)
  }

  /**
   * Records `data.path` when `data.yaml` has checkpoints and forgets it
   * otherwise.
   *
   * @param data - the path of a dvc.yaml file together with its parsed contents
   */
  public transformAndSet(data: { path: string; yaml: PartialDvcYaml }) {
    if (!collectHasCheckpoints(data.yaml)) {
      // No checkpoints (any more): drop the path if it was recorded.
      this.yamlWithCheckpoints = this.yamlWithCheckpoints.filter(
        recorded => recorded !== data.path
      )
      return
    }

    // uniqueValues is applied so a repeated update does not add the path twice.
    this.yamlWithCheckpoints = uniqueValues([
      ...this.yamlWithCheckpoints,
      data.path
    ])
  }
}

extension/src/experiments/index.ts

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import {
99
} from './model/filterBy/quickPick'
1010
import { pickSortsToRemove, pickSortToAdd } from './model/sortBy/quickPick'
1111
import { MetricsAndParamsModel } from './metricsAndParams/model'
12+
import { CheckpointsModel } from './checkpoints/model'
1213
import { ExperimentsData } from './data'
1314
import { TableData } from './webview/contract'
1415
import { ResourceLocator } from '../resourceLocator'
@@ -22,17 +23,20 @@ import {
2223
MessageFromWebviewType
2324
} from '../webview/contract'
2425
import { Logger } from '../common/logger'
26+
import { FileSystemData } from '../fileSystem/data'
2527

2628
export class Experiments extends BaseRepository<TableData> {
2729
public readonly onDidChangeExperiments: Event<ExperimentsOutput | void>
2830
public readonly onDidChangeMetricsOrParams: Event<void>
2931

3032
public readonly viewKey = ViewKey.EXPERIMENTS
3133

32-
private readonly data: ExperimentsData
34+
private readonly cliData: ExperimentsData
35+
private readonly fileSystemData: FileSystemData
3336

3437
private readonly experiments: ExperimentsModel
3538
private readonly metricsAndParams: MetricsAndParamsModel
39+
private readonly checkpoints: CheckpointsModel
3640

3741
private readonly experimentsChanged = this.dispose.track(
3842
new EventEmitter<ExperimentsOutput | void>()
@@ -48,7 +52,8 @@ export class Experiments extends BaseRepository<TableData> {
4852
updatesPaused: EventEmitter<boolean>,
4953
resourceLocator: ResourceLocator,
5054
workspaceState: Memento,
51-
data?: ExperimentsData
55+
cliData?: ExperimentsData,
56+
fileSystemData?: FileSystemData
5257
) {
5358
super(dvcRoot, resourceLocator.beaker)
5459

@@ -63,11 +68,22 @@ export class Experiments extends BaseRepository<TableData> {
6368
new MetricsAndParamsModel(dvcRoot, workspaceState)
6469
)
6570

66-
this.data = this.dispose.track(
67-
data || new ExperimentsData(dvcRoot, internalCommands, updatesPaused)
71+
this.checkpoints = this.dispose.track(new CheckpointsModel())
72+
73+
this.cliData = this.dispose.track(
74+
cliData || new ExperimentsData(dvcRoot, internalCommands, updatesPaused)
6875
)
6976

70-
this.dispose.track(this.data.onDidUpdate(data => this.setState(data)))
77+
this.fileSystemData = this.dispose.track(
78+
fileSystemData || new FileSystemData(dvcRoot)
79+
)
80+
81+
this.dispose.track(this.cliData.onDidUpdate(data => this.setState(data)))
82+
this.dispose.track(
83+
this.fileSystemData.onDidUpdate(data =>
84+
this.checkpoints.transformAndSet(data)
85+
)
86+
)
7187

7288
this.handleMessageFromWebview()
7389

@@ -81,11 +97,11 @@ export class Experiments extends BaseRepository<TableData> {
8197
}
8298

8399
public update() {
84-
return this.data.managedUpdate()
100+
return this.cliData.managedUpdate()
85101
}
86102

87103
public forceUpdate() {
88-
return this.data.forceUpdate()
104+
return this.cliData.forceUpdate()
89105
}
90106

91107
public async setState(data: ExperimentsOutput) {
@@ -97,6 +113,10 @@ export class Experiments extends BaseRepository<TableData> {
97113
return this.notifyChanged(data)
98114
}
99115

116+
public hasCheckpoints() {
117+
return this.checkpoints.hasCheckpoints()
118+
}
119+
100120
public getChildMetricsOrParams(path?: string) {
101121
return this.metricsAndParams.getChildren(path)
102122
}
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
import { Event, EventEmitter } from 'vscode'
2+
import { Disposable } from '@hediet/std/disposable'
3+
import { Deferred } from '@hediet/std/synchronization'
4+
import { isSameOrChild, loadYaml, PartialDvcYaml } from '..'
5+
import { findFiles } from '../workspace'
6+
import { join } from '../../test/util/path'
7+
import { createFileSystemWatcher } from '../watcher'
8+
9+
/**
 * Loads dvc.yaml files belonging to a DVC root and publishes their parsed
 * contents through `onDidUpdate`.
 *
 * An event is fired for each file found during initialization and again each
 * time the callback passed to `createFileSystemWatcher` is invoked with a path.
 */
export class FileSystemData {
  public readonly dispose = Disposable.fn()

  public readonly onDidUpdate: Event<{ path: string; yaml: PartialDvcYaml }>

  private readonly dvcRoot: string

  private readonly updated = this.dispose.track(
    new EventEmitter<{ path: string; yaml: PartialDvcYaml }>()
  )

  private readonly deferred = new Deferred()
  private readonly initialized = this.deferred.promise

  /**
   * @param dvcRoot - root directory used to filter found files and to build
   * the watcher glob
   */
  constructor(dvcRoot: string) {
    this.dvcRoot = dvcRoot
    this.onDidUpdate = this.updated.event

    this.watchDvcYaml()
    // Deliberately not awaited: completion is observed through isReady().
    void this.initialize()
  }

  /** Returns the promise that is resolved at the end of `initialize`. */
  public isReady() {
    return this.initialized
  }

  /**
   * Finds dvc.yaml files, keeps those for which `isSameOrChild` accepts the
   * DVC root, publishes each of them and then resolves the ready promise.
   */
  private async initialize() {
    const files = await findFiles(join('**', 'dvc.yaml'))
    const filesInRepo = files.filter(file => isSameOrChild(this.dvcRoot, file))

    for (const path of filesInRepo) {
      this.loadAndNotify(path)
    }

    this.deferred.resolve()
  }

  /** Registers a tracked watcher for dvc.yaml files under the DVC root. */
  private watchDvcYaml() {
    this.dispose.track(
      createFileSystemWatcher(join(this.dvcRoot, '**', 'dvc.yaml'), path => {
        if (!path) {
          return
        }
        this.loadAndNotify(path)
      })
    )
  }

  /**
   * Loads the yaml at `path` and fires an update when a truthy value is
   * returned.
   *
   * NOTE(review): when `loadYaml` returns a falsy value no event is fired, so
   * listeners keep their previous state for that path. Presumably this also
   * covers deleted or unparsable files; confirm whether that is intended.
   */
  private loadAndNotify(path: string) {
    const yaml = loadYaml<PartialDvcYaml>(path)
    if (yaml) {
      this.updated.fire({ path, yaml })
    }
  }
}

0 commit comments

Comments
 (0)