Skip to content

Commit 9f4f65a

Browse files
committed
add experiment checkpoints model
1 parent 456bf99 commit 9f4f65a

File tree

5 files changed

+196
-0
lines changed

5 files changed

+196
-0
lines changed
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
import { collectHasCheckpoints } from './collect'
2+
import { PartialDvcYaml } from '../../fileSystem'
3+
4+
// Unit tests for collectHasCheckpoints. Each case builds a dvc.yaml-shaped
// fixture and asserts on the boolean returned for it. The fixtures carry more
// keys than PartialDvcYaml declares, hence the explicit casts.
describe('collectHasCheckpoints', () => {
  it('should correctly identify that the demo project has checkpoints', () => {
    // The train stage lists model.pt as an output flagged with checkpoint: true.
    const yamlWithCheckpoint = {
      stages: {
        train: {
          cmd: 'python train.py',
          deps: ['data/MNIST', 'train.py'],
          live: { logs: { html: true, summary: true } },
          outs: [{ 'model.pt': { checkpoint: true } }],
          params: ['seed', 'lr', 'weight_decay'],
          plots: [
            'plots',
            {
              'predictions.json': {
                cache: false,
                template: 'confusion',
                x: 'actual',
                y: 'predicted'
              }
            }
          ]
        }
      }
    } as PartialDvcYaml

    expect(collectHasCheckpoints(yamlWithCheckpoint)).toBe(true)
  })

  it('should correctly classify a dvc.yaml without checkpoint', () => {
    // Outputs are present in both stages, but none carries a checkpoint flag.
    const yamlWithoutCheckpoint = {
      stages: {
        extract: {
          cmd: 'tar -xzf data/images.tar.gz --directory data',
          deps: ['data/images.tar.gz'],
          outs: [{ 'data/images/': { cache: false } }]
        },
        train: {
          cmd: 'python3 src/train.py',
          deps: ['data/images/', 'src/train.py'],
          live: { logs: { html: true, summary: true } },
          metrics: [{ 'metrics.json': { cache: false } }],
          outs: ['models/model.h5'],
          params: ['model.conv_units', 'train.epochs'],
          plots: [{ 'logs.csv': { cache: false } }]
        }
      }
    } as PartialDvcYaml

    expect(collectHasCheckpoints(yamlWithoutCheckpoint)).toBe(false)
  })

  it('should correctly classify a more complex dvc.yaml without checkpoint', () => {
    // Four stages; the evaluate stage has no outs at all and every declared
    // output is a plain path string.
    const multiStageYaml = {
      stages: {
        evaluate: {
          cmd: 'python src/evaluate.py model.pkl data/features scores.json prc.json roc.json',
          deps: ['data/features', 'model.pkl', 'src/evaluate.py'],
          metrics: [{ 'scores.json': { cache: false } }],
          plots: [
            { 'prc.json': { cache: false, x: 'recall', y: 'precision' } },
            { 'roc.json': { cache: false, x: 'fpr', y: 'tpr' } }
          ]
        },
        featurize: {
          cmd: 'python src/featurization.py data/prepared data/features',
          deps: ['data/prepared', 'src/featurization.py'],
          outs: ['data/features'],
          params: ['featurize.max_features', 'featurize.ngrams']
        },
        prepare: {
          cmd: 'python src/prepare.py data/data.xml',
          deps: ['data/data.xml', 'src/prepare.py'],
          outs: ['data/prepared'],
          params: ['prepare.seed', 'prepare.split']
        },
        train: {
          cmd: 'python src/train.py data/features model.pkl',
          deps: ['data/features', 'src/train.py'],
          outs: ['model.pkl'],
          params: ['train.min_split', 'train.n_est', 'train.seed']
        }
      }
    } as PartialDvcYaml

    expect(collectHasCheckpoints(multiStageYaml)).toBe(false)
  })
})
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
import { PartialDvcYaml } from '../../fileSystem'
2+
3+
/**
 * Tells whether a single `outs` entry flags any of its files as a checkpoint.
 *
 * Plain string entries (bare paths) never count. Object entries map a file
 * path to its options; the entry counts when at least one of those option
 * objects has a truthy `checkpoint` property.
 */
const hasCheckpointFlag = (out: unknown): boolean => {
  if (typeof out !== 'object' || out === null) {
    return false
  }

  return Object.values(out).some(file => Boolean(file?.checkpoint))
}

/**
 * Reports whether a parsed dvc.yaml declares at least one checkpoint output.
 *
 * Every stage is inspected, not only the one named `train`. A missing
 * `stages` map, a stage without an `outs` list, or an `outs` value that is
 * not an array is treated as "no checkpoints" instead of throwing.
 *
 * @param yaml - the parsed contents of a dvc.yaml file
 * @returns `true` when any output of any stage has a truthy `checkpoint`
 * option, `false` otherwise
 */
export const collectHasCheckpoints = (yaml: PartialDvcYaml): boolean => {
  const stages: Record<string, { outs?: unknown } | undefined> =
    yaml?.stages ?? {}

  return Object.values(stages).some(stage => {
    const outs = stage?.outs
    return Array.isArray(outs) && outs.some(hasCheckpointFlag)
  })
}
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
import { join } from 'path'
2+
import { ExperimentCheckpointsModel } from './model'
3+
import { dvcDemoPath } from '../../test/util'
4+
5+
// Exercises ExperimentCheckpointsModel with two yaml paths (root and extra),
// first registering both with checkpoints and then replacing each with a
// checkpoint-free version.
describe('ExperimentCheckpointsModel', () => {
  it('should keep a record of yaml files that have checkpoints', () => {
    const rootPath = join(dvcDemoPath, 'dvc.yaml')
    const extraPath = join(dvcDemoPath, 'extra', 'dvc.yaml')

    const model = new ExperimentCheckpointsModel()

    // Nothing has been set yet.
    expect(model.hasCheckpoints()).toBe(false)

    // Root yaml with a checkpoint output.
    model.transformAndSet({
      path: rootPath,
      yaml: {
        stages: {
          train: {
            outs: [{ 'model.pt': { checkpoint: true } }]
          }
        }
      }
    })
    expect(model.hasCheckpoints()).toBe(true)

    // Extra yaml with a checkpoint output.
    model.transformAndSet({
      path: extraPath,
      yaml: {
        stages: {
          train: {
            outs: [{ 'extra-model.pt': { checkpoint: true } }]
          }
        }
      }
    })
    expect(model.hasCheckpoints()).toBe(true)

    // Root yaml loses its checkpoint; the extra yaml still has one.
    model.transformAndSet({
      path: rootPath,
      yaml: {
        stages: { train: { outs: ['model.pt'] } }
      }
    })
    expect(model.hasCheckpoints()).toBe(true)

    // Extra yaml loses its checkpoint too, so no yaml with checkpoints remains.
    model.transformAndSet({
      path: extraPath,
      yaml: {
        stages: { train: { outs: ['extra-model.pt'] } }
      }
    })
    expect(model.hasCheckpoints()).toBe(false)
  })
})
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
import { Disposable } from '@hediet/std/disposable'
2+
import { collectHasCheckpoints } from './collect'
3+
import { PartialDvcYaml } from '../../fileSystem'
4+
import { definedAndNonEmpty, uniqueValues } from '../../util/array'
5+
6+
/**
 * Remembers which yaml files were last seen to contain checkpoints.
 *
 * Each call to `transformAndSet` evaluates one file with
 * `collectHasCheckpoints` and either records its path or removes it from the
 * record.
 */
export class ExperimentCheckpointsModel {
  public dispose = Disposable.fn()

  // Paths of the yaml files for which collectHasCheckpoints returned true.
  private yamlWithCheckpoints: string[] = []

  /**
   * Passes the recorded paths to `definedAndNonEmpty` and returns its result.
   */
  public hasCheckpoints() {
    return definedAndNonEmpty(this.yamlWithCheckpoints)
  }

  /**
   * Updates the record for a single yaml file.
   *
   * @param data - the file's path together with its parsed yaml contents
   */
  public transformAndSet(data: { path: string; yaml: PartialDvcYaml }) {
    const { path: yamlPath, yaml } = data

    if (!collectHasCheckpoints(yaml)) {
      // The file has no checkpoints (any more): drop its path if recorded.
      this.yamlWithCheckpoints = this.yamlWithCheckpoints.filter(
        recorded => recorded !== yamlPath
      )
      return
    }

    // Append the path and de-duplicate so repeated updates of the same file
    // do not pile up.
    this.yamlWithCheckpoints = uniqueValues(
      this.yamlWithCheckpoints.concat(yamlPath)
    )
  }
}

extension/src/fileSystem/index.ts

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,11 @@ export const isSameOrChild = (root: string, path: string) => {
5454
return !rel.startsWith('..')
5555
}
5656

57+
/**
 * The subset of a parsed dvc.yaml that is modelled here: the `outs` list of
 * the `train` stage.
 *
 * Each entry of `outs` is either a bare path string or an object keyed by
 * path whose value may carry an optional boolean `checkpoint` flag.
 */
export type PartialDvcYaml = {
  stages: {
    train: {
      outs: Array<string | Record<string, { checkpoint?: boolean }>>
    }
  }
}
5762
export const isAnyDvcYaml = (path?: string): boolean =>
5863
!!(
5964
path &&

0 commit comments

Comments
 (0)