diff --git a/extension/src/experiments/checkpoints/collect.test.ts b/extension/src/experiments/checkpoints/collect.test.ts new file mode 100644 index 0000000000..f1ea1c934a --- /dev/null +++ b/extension/src/experiments/checkpoints/collect.test.ts @@ -0,0 +1,89 @@ +import { collectHasCheckpoints } from './collect' +import { PartialDvcYaml } from '../../fileSystem' + +describe('collectHasCheckpoints', () => { + it('should correctly identify that the demo project has checkpoints', () => { + const hasCheckpoints = collectHasCheckpoints({ + stages: { + train: { + cmd: 'python train.py', + deps: ['data/MNIST', 'train.py'], + live: { logs: { html: true, summary: true } }, + outs: [{ 'model.pt': { checkpoint: true } }], + params: ['seed', 'lr', 'weight_decay'], + plots: [ + 'plots', + { + 'predictions.json': { + cache: false, + template: 'confusion', + x: 'actual', + y: 'predicted' + } + } + ] + } + } + } as PartialDvcYaml) + expect(hasCheckpoints).toBe(true) + }) + + it('should correctly classify a dvc.yaml without checkpoint', () => { + const hasCheckpoints = collectHasCheckpoints({ + stages: { + extract: { + cmd: 'tar -xzf data/images.tar.gz --directory data', + deps: ['data/images.tar.gz'], + outs: [{ 'data/images/': { cache: false } }] + }, + train: { + cmd: 'python3 src/train.py', + deps: ['data/images/', 'src/train.py'], + live: { logs: { html: true, summary: true } }, + metrics: [{ 'metrics.json': { cache: false } }], + outs: ['models/model.h5'], + params: ['model.conv_units', 'train.epochs'], + plots: [{ 'logs.csv': { cache: false } }] + } + } + } as PartialDvcYaml) + + expect(hasCheckpoints).toBe(false) + }) + + it('should correctly classify a more complex dvc.yaml without checkpoint', () => { + const hasCheckpoints = collectHasCheckpoints({ + stages: { + evaluate: { + cmd: 'python src/evaluate.py model.pkl data/features scores.json prc.json roc.json', + deps: ['data/features', 'model.pkl', 'src/evaluate.py'], + metrics: [{ 'scores.json': { cache: false } }], + 
/**
 * Reports whether a parsed dvc.yaml declares at least one checkpoint output.
 *
 * Every stage is inspected, not only one named `train`, so a pipeline that
 * keeps its checkpointed output in a differently named stage is detected and
 * a pipeline with no `train` stage no longer throws. A missing `stages` map,
 * a stage without `outs`, and plain string outputs all count as "no
 * checkpoint".
 *
 * @param yaml - the parsed contents of a dvc.yaml file
 * @returns `true` when any output entry of any stage has a truthy
 * `checkpoint` flag, otherwise `false`
 */
export const collectHasCheckpoints = (yaml: PartialDvcYaml): boolean =>
  Object.values(yaml?.stages ?? {}).some(stage =>
    (stage?.outs ?? []).some(
      out =>
        // A bare string output carries no options, so it cannot be flagged
        // as a checkpoint; only the object form needs to be inspected.
        typeof out !== 'string' &&
        Object.values(out ?? {}).some(file => !!file?.checkpoint)
    )
  )
/**
 * Keeps a list of the dvc.yaml paths for which `collectHasCheckpoints`
 * returned a truthy result the last time that path was submitted.
 */
export class CheckpointsModel {
  public dispose = Disposable.fn()

  // Paths of the yaml files currently recorded as having checkpoints.
  private yamlWithCheckpoints: string[] = []

  /**
   * @returns the result of `definedAndNonEmpty` for the recorded paths
   * (presumably `true` once at least one path is recorded — the helper is
   * defined outside this file, confirm there).
   */
  public hasCheckpoints() {
    return definedAndNonEmpty(this.yamlWithCheckpoints)
  }

  /**
   * Re-evaluates one yaml file and records or forgets its path accordingly.
   *
   * @param data - the file's `path` together with its parsed `yaml`
   */
  public transformAndSet(data: { path: string; yaml: PartialDvcYaml }) {
    const yamlPath = data.path

    if (!collectHasCheckpoints(data.yaml)) {
      // The file has no checkpoints (any more): drop its path if present.
      this.yamlWithCheckpoints = this.yamlWithCheckpoints.filter(
        recorded => recorded !== yamlPath
      )
      return
    }

    // NOTE(review): uniqueValues is assumed to de-duplicate so a path that is
    // submitted twice is stored once — verify against util/array.
    this.yamlWithCheckpoints = uniqueValues([
      ...this.yamlWithCheckpoints,
      yamlPath
    ])
  }
}