Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow Multi File Select on Plot Wizard #4748

Merged
merged 13 commits into from
Oct 5, 2023
151 changes: 120 additions & 31 deletions extension/src/fileSystem/index.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import {
getPidFromFile,
getEntryFromJsonFile,
addPlotToDvcYamlFile,
loadDataFile
loadDataFiles
} from '.'
import { dvcDemoPath } from '../test/util'
import { DOT_DVC } from '../cli/dvc/constants'
Expand Down Expand Up @@ -63,17 +63,22 @@ beforeEach(() => {
jest.resetAllMocks()
})

describe('loadDataFile', () => {
describe('loadDataFiles', () => {
it('should load in csv file contents', async () => {
const mockCsvContent = ['epoch,acc', '10,0.69', '11,0.345'].join('\n')

mockedReadFileSync.mockReturnValueOnce(mockCsvContent)

const result = await loadDataFile('values.csv')
const result = await loadDataFiles(['values.csv'])

expect(result).toStrictEqual([
{ acc: 0.69, epoch: 10 },
{ acc: 0.345, epoch: 11 }
{
data: [
{ acc: 0.69, epoch: 10 },
{ acc: 0.345, epoch: 11 }
],
file: 'values.csv'
}
])
})

Expand All @@ -85,11 +90,16 @@ describe('loadDataFile', () => {

mockedReadFileSync.mockReturnValueOnce(mockJsonContent)

const result = await loadDataFile('values.json')
const result = await loadDataFiles(['values.json'])

expect(result).toStrictEqual([
{ acc: 0.69, epoch: 10 },
{ acc: 0.345, epoch: 11 }
{
data: [
{ acc: 0.69, epoch: 10 },
{ acc: 0.345, epoch: 11 }
],
file: 'values.json'
}
])
})

Expand All @@ -98,11 +108,16 @@ describe('loadDataFile', () => {

mockedReadFileSync.mockReturnValueOnce(mockTsvContent)

const result = await loadDataFile('values.tsv')
const result = await loadDataFiles(['values.tsv'])

expect(result).toStrictEqual([
{ acc: 0.69, epoch: 10 },
{ acc: 0.345, epoch: 11 }
{
data: [
{ acc: 0.69, epoch: 10 },
{ acc: 0.345, epoch: 11 }
],
file: 'values.tsv'
}
])
})

Expand All @@ -115,15 +130,47 @@ describe('loadDataFile', () => {

mockedReadFileSync.mockReturnValueOnce(mockYamlContent)

const result = await loadDataFile('dvc.yaml')
const result = await loadDataFiles(['dvc.yaml'])

expect(result).toStrictEqual({
stages: {
train: {
cmd: 'python train.py'
}
expect(result).toStrictEqual([
{
data: {
stages: {
train: {
cmd: 'python train.py'
}
}
},
file: 'dvc.yaml'
}
})
])
})

it('should load in the contents of multiple files', async () => {
const mockTsvContent = ['epoch\tacc', '10\t0.69', '11\t0.345'].join('\n')
const mockCsvContent = ['epoch2,acc2', '10,0.679', '11,0.3'].join('\n')

mockedReadFileSync.mockReturnValueOnce(mockTsvContent)
mockedReadFileSync.mockReturnValueOnce(mockCsvContent)

const result = await loadDataFiles(['values.tsv', 'values2.csv'])

expect(result).toStrictEqual([
{
data: [
{ acc: 0.69, epoch: 10 },
{ acc: 0.345, epoch: 11 }
],
file: 'values.tsv'
},
{
data: [
{ acc2: 0.679, epoch2: 10 },
{ acc2: 0.3, epoch2: 11 }
],
file: 'values2.csv'
}
])
})

it('should catch any errors thrown during file parsing', async () => {
Expand All @@ -133,11 +180,29 @@ describe('loadDataFile', () => {
})

for (const file of dataFiles) {
const resultWithErr = await loadDataFile(file)
const resultWithErr = await loadDataFiles([file])

expect(resultWithErr).toStrictEqual(undefined)
}
})

it('should catch any errors thrown during the parsing of multiple files', async () => {
const dataFiles = ['values.csv', 'file.tsv', 'file.json']
const mockCsvContent = ['epoch,acc', '10,0.69', '11,0.345'].join('\n')
const mockJsonContent = JSON.stringify([
{ acc: 0.69, epoch: 10 },
{ acc: 0.345, epoch: 11 }
])
mockedReadFileSync
.mockReturnValueOnce(mockCsvContent)
.mockImplementationOnce(() => {
throw new Error('fake error')
})
.mockReturnValueOnce(mockJsonContent)

const resultWithErr = await loadDataFiles(dataFiles)
expect(resultWithErr).toStrictEqual(undefined)
})
})

describe('writeJson', () => {
Expand Down Expand Up @@ -541,10 +606,37 @@ describe('addPlotToDvcYamlFile', () => {
mockedReadFileSync.mockReturnValueOnce(mockDvcYamlContent)

addPlotToDvcYamlFile('/', {
dataFile: '/data.json',
template: 'simple',
x: 'epochs',
y: 'accuracy'
x: { file: '/data.json', key: 'epochs' },
y: { file: '/data.json', key: 'accuracy' }
})

expect(mockedWriteFileSync).toHaveBeenCalledWith(
'//dvc.yaml',
mockDvcYamlContent + mockPlotYamlContent
)
})

it('should add the new plot with fields coming from different files', () => {
const mockDvcYamlContent = mockStagesLines.join('\n')
const mockPlotYamlContent = [
'',
'plots:',
' - simple_plot:',
' template: simple',
' x:',
' data.json: epochs',
' y:',
' acc.json: accuracy',
''
].join('\n')
mockedReadFileSync.mockReturnValueOnce(mockDvcYamlContent)
mockedReadFileSync.mockReturnValueOnce(mockDvcYamlContent)

addPlotToDvcYamlFile('/', {
template: 'simple',
x: { file: '/data.json', key: 'epochs' },
y: { file: '/acc.json', key: 'accuracy' }
})

expect(mockedWriteFileSync).toHaveBeenCalledWith(
Expand All @@ -560,10 +652,9 @@ describe('addPlotToDvcYamlFile', () => {
mockedReadFileSync.mockReturnValueOnce(mockDvcYamlContent.join('\n'))

addPlotToDvcYamlFile('/', {
dataFile: '/data.json',
template: 'simple',
x: 'epochs',
y: 'accuracy'
x: { file: '/data.json', key: 'epochs' },
y: { file: '/data.json', key: 'accuracy' }
})

mockDvcYamlContent.splice(7, 0, ...mockPlotYamlContent)
Expand All @@ -583,10 +674,9 @@ describe('addPlotToDvcYamlFile', () => {
mockedReadFileSync.mockReturnValueOnce(mockDvcYamlContent)

addPlotToDvcYamlFile('/', {
dataFile: '/data.json',
template: 'simple',
x: 'epochs',
y: 'accuracy'
x: { file: '/data.json', key: 'epochs' },
y: { file: '/data.json', key: 'accuracy' }
})

expect(mockedWriteFileSync).toHaveBeenCalledWith(
Expand Down Expand Up @@ -620,10 +710,9 @@ describe('addPlotToDvcYamlFile', () => {
mockedReadFileSync.mockReturnValueOnce(mockDvcYamlContent)

addPlotToDvcYamlFile('/', {
dataFile: '/data.json',
template: 'simple',
x: 'epochs',
y: 'accuracy'
x: { file: '/data.json', key: 'epochs' },
y: { file: '/data.json', key: 'accuracy' }
})

expect(mockedWriteFileSync).toHaveBeenCalledWith(
Expand Down
41 changes: 37 additions & 4 deletions extension/src/fileSystem/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -214,21 +214,38 @@ const loadYamlAsDoc = (
}
}

const getPlotYamlObj = (cwd: string, plot: PlotConfigData) => {
const { x, y, template } = plot
const usesSingleFile = x.file === y.file
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We check for multiple files being used and create the plot object accordingly:

plots:
  # two files
  - scatter_plot:
      template: scatter
      x:
        props.json: acc
      y:
        values.json: prob
  # single files
  - probs.json:
      x: actual
      y: prob

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[Q] Do we need this split? Can we make the creation of plots entries a bit more opinionated from the wizard and reduce complexity? Otherwise, when we come to add in custom titles we will have to complicate the differentiation further.

From our demo dvc.yaml:

  - Loss:
      x: step
      y:
        training/plots/metrics/train/loss.tsv: loss
        training/plots/metrics/test/loss.tsv: loss
      y_label: loss
  - Confusion matrix:
      template: confusion
      x: actual
      y:
        training/plots/sklearn/confusion_matrix.json: predicted
  - hist.csv:
      x: preds
      y: digit
      template: bar_horizontal
      title: Histogram of Predictions

Maybe we want to stick to the first two entry types.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated the two plot types to be more similar:

plots:
  # two files
  - scatter_plot:
      template: scatter
      x:
        props.json: acc
      y:
        values.json: prob
  # single files
  - simple_plot:
      template: simple
      x: actual
      y:
        probs.json: prob

if (usesSingleFile) {
const dataFile = x.file
const plotName = relative(cwd, dataFile)
return { [plotName]: { template, x: x.key, y: y.key } }
}

const plotName = `${template}_plot`
return {
[plotName]: {
template,
x: { [relative(cwd, x.file)]: x.key },
y: { [relative(cwd, y.file)]: y.key }
}
}
}

const getPlotsYaml = (
cwd: string,
plotObj: PlotConfigData,
indentSearchLines: string[]
) => {
const { dataFile, ...plot } = plotObj
const plotName = relative(cwd, dataFile)
const indentReg = /^( +)[^ ]/
const indentLine = indentSearchLines.find(line => indentReg.test(line)) || ''
const spacesMatches = indentLine.match(indentReg)
const spaces = spacesMatches?.[1]

return yaml
.stringify(
{ plots: [{ [plotName]: plot }] },
{ plots: [getPlotYamlObj(cwd, plotObj)] },
{ indent: spaces ? spaces.length : 2 }
)
.split('\n')
Expand Down Expand Up @@ -315,7 +332,7 @@ const loadTsv = (path: string) => {
}
}

export const loadDataFile = (file: string): unknown => {
const loadDataFile = (file: string): unknown => {
const ext = getFileExtension(file)

switch (ext) {
Expand All @@ -330,6 +347,22 @@ export const loadDataFile = (file: string): unknown => {
}
}

export const loadDataFiles = async (
files: string[]
): Promise<{ file: string; data: unknown }[] | undefined> => {
const filesData: { file: string; data: unknown }[] = []
for (const file of files) {
const data = await loadDataFile(file)

if (!data) {
return undefined
}

filesData.push({ data, file })
}
return filesData
}

export const writeJson = <
T extends Record<string, unknown> | Array<Record<string, unknown>>
>(
Expand Down
2 changes: 1 addition & 1 deletion extension/src/pipeline/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ export class Pipeline extends DeferredDisposable {
return
}

const plotConfiguration = await pickPlotConfiguration()
const plotConfiguration = await pickPlotConfiguration(cwd)

if (!plotConfiguration) {
return
Expand Down
Loading
Loading