diff --git a/.vscode/settings.json b/.vscode/settings.json index 43c28f4e41..f35eed42cf 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -45,6 +45,7 @@ "unaddable", "uncommit", "uniqwith", + "unmerge", "unprotect", "unshallow", "unstage", diff --git a/extension/src/plots/model/collect.ts b/extension/src/plots/model/collect.ts index c3e1552cb9..682a512ba2 100644 --- a/extension/src/plots/model/collect.ts +++ b/extension/src/plots/model/collect.ts @@ -33,6 +33,13 @@ import { TemplateOrder } from '../paths/collect' import { extendVegaSpec, isMultiViewPlot } from '../vega/util' import { definedAndNonEmpty, splitMatchedOrdered } from '../../util/array' import { shortenForLabel } from '../../util/string' +import { + getDvcDataVersionInfo, + isConcatenatedField, + mergeFields, + MultiSourceEncoding, + unmergeConcatenatedFields +} from '../multiSource/collect' type CheckpointPlotAccumulator = { iterations: Record @@ -355,7 +362,14 @@ const collectDatapoints = ( values: Record[] = [] ) => { for (const value of values) { - ;(acc[rev][path] as unknown[]).push({ ...value, rev }) + const dvc_data_version_info = getDvcDataVersionInfo(value) + const data: { rev: string } = { + ...value, + ...dvc_data_version_info, + rev + } + + ;(acc[rev][path] as unknown[]).push(data) } } @@ -487,10 +501,33 @@ export const collectTemplates = (data: PlotsOutput): TemplateAccumulator => { return acc } -const fillTemplate = (template: string, datapoints: unknown[]) => - JSON.parse( - template.replace('""', JSON.stringify(datapoints)) +const fillTemplate = ( + template: string, + datapoints: unknown[], + field?: string +) => { + if (!field || !isConcatenatedField(field)) { + return JSON.parse( + template.replace('""', JSON.stringify(datapoints)) + ) as TopLevelSpec + } + + const fields = unmergeConcatenatedFields(field) + return JSON.parse( + template.replace( + '""', + JSON.stringify( + datapoints.map(data => { + const obj = data as Record + return { + ...obj, + [field]: mergeFields(fields.map(field => obj[field] as string)) + } + }) + ) + ) ) as TopLevelSpec +} const collectTemplateGroup = ( paths: string[], @@ -498,7 +535,8 @@ const collectTemplateGroup = ( templates: TemplateAccumulator, revisionData: RevisionData, size: PlotSize, - revisionColors: ColorScale | undefined + revisionColors: ColorScale | undefined, + multiSourceEncoding: MultiSourceEncoding ): TemplatePlotEntry[] => { const acc: TemplatePlotEntry[] = [] for (const path of paths) { @@ -509,10 +547,19 @@ const collectTemplateGroup = ( .flatMap(revision => revisionData?.[revision]?.[path]) .filter(Boolean) + const multiSourceEncodingUpdate = multiSourceEncoding[path] || {} + const content = extendVegaSpec( - fillTemplate(template, datapoints), + fillTemplate( + template, + datapoints, + multiSourceEncodingUpdate.strokeDash?.field + ), size, - revisionColors + { + ...multiSourceEncodingUpdate, + color: revisionColors + } ) acc.push({ @@ -533,7 +580,8 @@ export const collectSelectedTemplatePlots = ( templates: TemplateAccumulator, revisionData: RevisionData, size: PlotSize, - revisionColors: ColorScale | undefined + revisionColors: ColorScale | undefined, + multiSourceEncoding: MultiSourceEncoding ): TemplatePlotSection[] | undefined => { const acc: TemplatePlotSection[] = [] for (const templateGroup of order) { @@ -544,7 +592,8 @@ export const collectSelectedTemplatePlots = ( templates, revisionData, size, - revisionColors + revisionColors, + multiSourceEncoding ) if (!definedAndNonEmpty(entries)) { continue diff --git a/extension/src/plots/model/index.ts b/extension/src/plots/model/index.ts index 7d8637554f..3e034519d4 100644 --- a/extension/src/plots/model/index.ts +++ b/extension/src/plots/model/index.ts @@ -33,6 +33,12 @@ import { removeMissingKeysFromObject } from '../../util/object' import { TemplateOrder } from '../paths/collect' import { PersistenceKey } from '../../persistence/constants' import { ModelWithPersistence } from '../../persistence/model' +import { + collectMultiSourceEncoding, + collectMultiSourceVariations, + MultiSourceEncoding, + MultiSourceVariations +} from '../multiSource/collect' export class PlotsModel extends ModelWithPersistence { private readonly experiments: Experiments @@ -49,6 +55,8 @@ export class PlotsModel extends ModelWithPersistence { private revisionData: RevisionData = {} private templates: TemplateAccumulator = {} + private multiSourceVariations: MultiSourceVariations = {} + private multiSourceEncoding: MultiSourceEncoding = {} private checkpointPlots?: CheckpointPlot[] private selectedMetrics?: string[] @@ -104,10 +112,12 @@ export class PlotsModel extends ModelWithPersistence { ...revs.map(rev => cliIdToLabel[rev]) ]) - const [{ comparisonData, revisionData }, templates] = await Promise.all([ - collectData(data, cliIdToLabel), - collectTemplates(data) - ]) + const [{ comparisonData, revisionData }, templates, multiSourceVariations] = + await Promise.all([ + collectData(data, cliIdToLabel), + collectTemplates(data), + collectMultiSourceVariations(data, this.multiSourceVariations) + ]) const { overwriteComparisonData, overwriteRevisionData } = collectWorkspaceRaceConditionData( @@ -127,6 +137,10 @@ export class PlotsModel extends ModelWithPersistence { ...overwriteRevisionData } this.templates = { ...this.templates, ...templates } + this.multiSourceVariations = multiSourceVariations + this.multiSourceEncoding = collectMultiSourceEncoding( + this.multiSourceVariations + ) this.setComparisonOrder() @@ -422,7 +436,8 @@ export class PlotsModel extends ModelWithPersistence { this.templates, this.revisionData, this.getPlotSize(Section.TEMPLATE_PLOTS), - this.getRevisionColors() + this.getRevisionColors(), + this.multiSourceEncoding ) } } diff --git a/extension/src/plots/multiSource/collect.test.ts b/extension/src/plots/multiSource/collect.test.ts new file mode 100644 index 0000000000..7ccb3a102c --- /dev/null +++ b/extension/src/plots/multiSource/collect.test.ts @@ -0,0 +1,168 @@ +import { join } from 'path' +import { collectMultiSourceEncoding } from './collect' + +describe('collectMultiSourceEncoding', () => { + it('should return an empty object given a single variation collected from the datapoints', () => { + const multiSourceEncoding = collectMultiSourceEncoding({ + path: [{ field: 'x', filename: 'path' }] + }) + expect(multiSourceEncoding).toStrictEqual({}) + }) + + it('should return an object containing a filename strokeDash given variations with differing filenames', () => { + const otherPath = join('other', 'path') + const multiSourceEncoding = collectMultiSourceEncoding({ + combined: [ + { field: 'x', filename: 'path' }, + { field: 'x', filename: otherPath } + ] + }) + expect(multiSourceEncoding).toStrictEqual({ + combined: { + strokeDash: { + field: 'filename', + scale: { + domain: [otherPath, 'path'], + range: [ + [1, 0], + [8, 8] + ] + } + } + } + }) + }) + + it('should return an object containing a field strokeDash given variations with differing fields', () => { + const multiSourceEncoding = collectMultiSourceEncoding({ + path: [ + { field: 'x', filename: 'path' }, + { field: 'z', filename: 'path' } + ] + }) + expect(multiSourceEncoding).toStrictEqual({ + path: { + strokeDash: { + field: 'field', + scale: { + domain: ['x', 'z'], + range: [ + [1, 0], + [8, 8] + ] + } + } + } + }) + }) + + it('should return an object containing a merged filename::field strokeDash given variations with differing filename and fields', () => { + const otherPath = join('other', 'path') + const multiSourceEncoding = collectMultiSourceEncoding({ + combined: [ + { field: 'x', filename: 'path' }, + { field: 'z', filename: otherPath } + ] + }) + expect(multiSourceEncoding).toStrictEqual({ + combined: { + strokeDash: { + field: 'filename::field', + scale: { + domain: [`${otherPath}::z`, 'path::x'], + range: [ + [1, 0], + [8, 8] + ] + } + } + } + }) + }) + + it('should return an object containing a merged filename::field strokeDash given variations with differing filename and similar field', () => { + const multiSourceEncoding = collectMultiSourceEncoding({ + combined: [ + { field: 'x', filename: join('first', 'path') }, + { field: 'z', filename: join('second', 'path') }, + { field: 'z', filename: join('third', 'path') } + ] + }) + expect(multiSourceEncoding).toStrictEqual({ + combined: { + strokeDash: { + field: 'filename::field', + scale: { + domain: [ + `${join('first', 'path')}::x`, + `${join('second', 'path')}::z`, + `${join('third', 'path')}::z` + ], + range: [ + [1, 0], + [8, 8], + [8, 4] + ] + } + } + } + }) + }) + + it('should return an object containing a merged filename::field strokeDash given variations with differing filename and field for each variation', () => { + const multiSourceEncoding = collectMultiSourceEncoding({ + combined: [ + { field: 'x', filename: join('first', 'path') }, + { field: 'z', filename: join('second', 'path') }, + { field: 'q', filename: join('third', 'path') } + ] + }) + expect(multiSourceEncoding).toStrictEqual({ + combined: { + strokeDash: { + field: 'filename::field', + scale: { + domain: [ + `${join('first', 'path')}::x`, + `${join('second', 'path')}::z`, + `${join('third', 'path')}::q` + ], + range: [ + [1, 0], + [8, 8], + [8, 4] + ] + } + } + } + }) + }) + + it('should return an object containing a filename strokeDash and field shape given variations with unmergable combinations of filename and field', () => { + const multiSourceEncoding = collectMultiSourceEncoding({ + combined: [ + { field: 'x', filename: 'path' }, + { field: 'z', filename: 'path' }, + { field: 'z', filename: join('other', 'path') } + ] + }) + expect(multiSourceEncoding).toStrictEqual({ + combined: { + shape: { + field: 'field', + scale: { domain: ['x', 'z'], range: ['square', 'circle'] } + }, + strokeDash: { + field: 'filename', + scale: { + domain: [join('other', 'path'), 'path'], + range: [ + [1, 0], + [8, 8] + ] + } + } + } + }) + }) +}) diff --git a/extension/src/plots/multiSource/collect.ts b/extension/src/plots/multiSource/collect.ts new file mode 100644 index 0000000000..ef6523e7e1 --- /dev/null +++ b/extension/src/plots/multiSource/collect.ts @@ -0,0 +1,325 @@ +import isEqual from 'lodash.isequal' +import { + ShapeEncoding, + StrokeDashEncoding, + StrokeDashScale, + StrokeDash, + Shape, + StrokeDashValue, + ShapeValue +} from './constants' +import { isImagePlot, Plot, TemplatePlot } from '../webview/contract' +import { PlotsOutput } from '../../cli/dvc/reader' + +const FIELD_SEPARATOR = '::' + +type Values = { filename?: Set; field?: Set } + +type Variation = { + filename?: string | undefined + field?: string | undefined +} +type Variations = Variation[] + +export type MultiSourceVariations = Record + +export type MultiSourceEncoding = Record< + string, + { + strokeDash: StrokeDashEncoding + shape?: ShapeEncoding + } +> + +export const mergeFields = (fields: string[]): string => + fields.join(FIELD_SEPARATOR) + +export const isConcatenatedField = (field: string): boolean => + !!field.includes(FIELD_SEPARATOR) + +export const unmergeConcatenatedFields = (field: string): string[] => + field.split(FIELD_SEPARATOR) + +export const getDvcDataVersionInfo = ( + value: Record +): Record => { + const dvcDataVersionInfo = + (value.dvc_data_version_info as Record) || {} + if (dvcDataVersionInfo.revision) { + delete dvcDataVersionInfo.revision + } + + return dvcDataVersionInfo +} + +const collectPlotMultiSourceVariations = ( + acc: Record[]>, + path: string, + plot: TemplatePlot +) => { + if (!acc[path]) { + acc[path] = [] + } + + for (const value of Object.values(plot.datapoints || {}).flat()) { + const dvcDataVersionInfo = getDvcDataVersionInfo(value) + + if (acc[path].some(obj => isEqual(obj, dvcDataVersionInfo))) { + continue + } + acc[path].push(dvcDataVersionInfo) + } +} + +const collectPathMultiSourceVariations = ( + acc: Record[]>, + path: string, + plots: Plot[] +) => { + for (const plot of plots) { + if (isImagePlot(plot)) { + continue + } + + collectPlotMultiSourceVariations(acc, path, plot) + } +} + +export const collectMultiSourceVariations = ( + data: PlotsOutput, + acc: Record[]> +) => { + for (const [path, plots] of Object.entries(data)) { + collectPathMultiSourceVariations(acc, path, plots) + } + + return acc +} + +const initializeAcc = (values: Values, key: 'filename' | 'field') => { + if (!values[key]) { + values[key] = new Set() + } +} + +const collectMultiSourceValue = (values: Values, variation: Variation) => { + for (const [key, value] of Object.entries(variation)) { + if (key !== 'filename' && key !== 'field') { + continue + } + initializeAcc(values, key) + ;(values[key] as Set).add(value) + } +} + +const collectMultiSourceValues = (variations: Variations): Values => { + const values: Values = {} + for (const variation of variations) { + collectMultiSourceValue(values, variation) + } + return values +} + +const sortDifferentVariations = ( + differentVariations: { field: string; variations: number }[], + expectedOrder: string[] +): { field: string; variations: number }[] => { + differentVariations.sort( + ( + { field: aField, variations: aVariations }, + { field: bField, variations: bVariations } + ) => + aVariations === bVariations + ? expectedOrder.indexOf(aField) - expectedOrder.indexOf(bField) + : aVariations - bVariations + ) + return differentVariations +} + +const groupVariations = ( + variations: Variations, + values: Values +): { + lessValuesThanVariations: { field: string; variations: number }[] + valuesMatchVariations: string[] +} => { + const valuesMatchVariations: string[] = [] + const lessValuesThanVariations: { field: string; variations: number }[] = [] + + for (const [field, valueSet] of Object.entries(values)) { + if (valueSet.size === 1) { + continue + } + if (valueSet.size === variations.length) { + valuesMatchVariations.push(field) + continue + } + lessValuesThanVariations.push({ field, variations: valueSet.size }) + } + + const expectedOrder = ['filename', 'field'] + return { + lessValuesThanVariations: sortDifferentVariations( + lessValuesThanVariations, + expectedOrder + ), + valuesMatchVariations: valuesMatchVariations.sort( + (a, b) => expectedOrder.indexOf(a) - expectedOrder.indexOf(b) + ) + } +} + +const collectVariation = ( + scale: StrokeDashScale, + keysToCombine: string[], + field: Set, + idx: number, + variation: Variation +): void => { + const domain: string[] = [] + + for (const key of keysToCombine as (keyof typeof variation)[]) { + if (variation[key]) { + domain.push(variation[key] as string) + field.add(key) + } + } + + scale.domain.push(mergeFields(domain)) + scale.range.push(StrokeDash[idx]) + scale.domain.sort() +} + +const getEncoding = ( + field: Set, + scale: { domain: string[]; range: T[] } +): { field: string; scale: { domain: string[]; range: T[] } } => ({ + field: mergeFields([...field]), + scale +}) + +const collectMergedStrokeDashEncoding = ( + acc: MultiSourceEncoding, + path: string, + variations: Variations, + keysToCombine: string[] +): void => { + const scale: StrokeDashScale = { + domain: [], + range: [] + } + let idx = 0 + const field = new Set() + + for (const variation of variations) { + collectVariation(scale, keysToCombine, field, idx, variation) + idx++ + } + + acc[path] = { + strokeDash: getEncoding(field, scale) + } +} + +const collectEncodingFromValues = ( + scaleRange: T, + values: Values, + lessValuesThanVariations: { field: string; variations: number }[] +): { field: string; scale: { range: T[number][]; domain: string[] } } => { + const scale: { range: T[number][]; domain: string[] } = { + domain: [], + range: [] + } + const filenameOrField = lessValuesThanVariations.shift() + let idx = 0 + const field = new Set() + if (filenameOrField?.field) { + for (const value of values[filenameOrField.field as 'filename' | 'field'] || + []) { + field.add(filenameOrField.field) + scale.domain.push(value) + scale.range.push(scaleRange[idx]) + scale.domain.sort() + idx++ + } + } + return getEncoding(field, scale) +} + +const collectUnmergedStrokeDashEncoding = ( + acc: MultiSourceEncoding, + path: string, + values: Values, + lessValuesThanVariations: { field: string; variations: number }[] +): void => { + acc[path] = { + strokeDash: collectEncodingFromValues( + StrokeDash, + values, + lessValuesThanVariations + ) + } +} + +const collectUnmergedShapeEncoding = ( + acc: MultiSourceEncoding, + path: string, + values: Values, + lessValuesThanVariations: { field: string; variations: number }[] +): void => { + acc[path] = { + ...acc[path], + shape: collectEncodingFromValues(Shape, values, lessValuesThanVariations) + } +} + +const collectPathMultiSourceEncoding = ( + acc: MultiSourceEncoding, + path: string, + variations: Variations +): void => { + const values = collectMultiSourceValues(variations) + + const { valuesMatchVariations, lessValuesThanVariations } = groupVariations( + variations, + values + ) + + if (valuesMatchVariations.length > 0) { + const keysToCombined = [ + ...valuesMatchVariations, + ...lessValuesThanVariations.map(({ field }) => field) + ] + collectMergedStrokeDashEncoding(acc, path, variations, keysToCombined) + return + } + + if (lessValuesThanVariations.length > 0 && !acc[path]?.strokeDash) { + collectUnmergedStrokeDashEncoding( + acc, + path, + values, + lessValuesThanVariations + ) + } + + if (lessValuesThanVariations.length > 0) { + collectUnmergedShapeEncoding(acc, path, values, lessValuesThanVariations) + } +} + +export const collectMultiSourceEncoding = ( + data: Record +): MultiSourceEncoding => { + const acc: MultiSourceEncoding = {} + + for (const [path, variations] of Object.entries(data)) { + if (variations.length <= 1) { + continue + } + + collectPathMultiSourceEncoding(acc, path, variations) + } + + return acc +} diff --git a/extension/src/plots/multiSource/constants.ts b/extension/src/plots/multiSource/constants.ts new file mode 100644 index 0000000000..725a811cd9 --- /dev/null +++ b/extension/src/plots/multiSource/constants.ts @@ -0,0 +1,34 @@ +export const StrokeDash = [ + [1, 0], + [8, 8], + [8, 4], + [4, 4], + [4, 2], + [2, 1], + [1, 1] +] as const +export type StrokeDashValue = typeof StrokeDash[number] + +export const Shape = [ + 'square', + 'circle', + 'triangle', + 'diamond', + 'cross' +] as const +export type ShapeValue = typeof Shape[number] + +export type Scale = { + domain: string[] + range: T[] +} + +export type Encoding = { + scale: Scale +} & { field: string } + +export type StrokeDashScale = Scale +export type StrokeDashEncoding = Encoding + +export type ShapeScale = Scale +export type ShapeEncoding = Encoding diff --git a/extension/src/plots/vega/util.test.ts b/extension/src/plots/vega/util.test.ts index 9ecdd9e883..b4994f93c4 100644 --- a/extension/src/plots/vega/util.test.ts +++ b/extension/src/plots/vega/util.test.ts @@ -1,10 +1,13 @@ import { Text as VegaText, Title as VegaTitle } from 'vega' import { TopLevelSpec } from 'vega-lite' +import merge from 'lodash.merge' import { isMultiViewPlot, isMultiViewByCommitPlot, extendVegaSpec, - getColorScale + getColorScale, + Encoding, + reverseOfLegendSuppressionUpdate } from './util' import confusionTemplate from '../../test/fixtures/plotsDiff/templates/confusion' import confusionNormalizedTemplate from '../../test/fixtures/plotsDiff/templates/confusionNormalized' @@ -87,11 +90,9 @@ describe('extendVegaSpec', () => { domain: ['workspace', 'main'], range: copyOriginalColors().slice(0, 2) } - const extendedSpec = extendVegaSpec( - linearTemplate, - PlotSize.REGULAR, - colorScale - ) + const extendedSpec = extendVegaSpec(linearTemplate, PlotSize.REGULAR, { + color: colorScale + }) expect(extendedSpec).not.toStrictEqual(defaultTemplate) expect(extendedSpec.encoding.color).toStrictEqual({ @@ -264,3 +265,44 @@ describe('extendVegaSpec', () => { expect(updatedSpecString).toContain(truncatedTitle) }) }) + +describe('reverseOfLegendSuppressionUpdate', () => { + it('should reverse the legend suppression applied by extendVegaSpec', () => { + type NonOptionalEncoding = { [P in keyof Encoding]-?: Encoding[P] } + const update: NonOptionalEncoding = { + color: { + legend: { + disable: true + }, + scale: { domain: [], range: [] } + }, + detail: { + field: 'shape-field' + }, + shape: { + field: 'shape-field', + legend: { + disable: true + }, + scale: { domain: [], range: [] } + }, + strokeDash: { + field: 'strokeDash-field', + legend: { + disable: true + }, + scale: { domain: [], range: [] } + } + } + + expect(JSON.stringify(update)).toContain('"legend":{"disable":true}') + + const reverse = reverseOfLegendSuppressionUpdate() + + const result = JSON.stringify( + merge({ spec: { encoding: update } }, reverse) + ) + expect(result).not.toContain('"legend":{"disable":true}') + expect(result).toContain('"legend":{"disable":false}') + }) +}) diff --git a/extension/src/plots/vega/util.ts b/extension/src/plots/vega/util.ts index 4f6bb7c2db..c6aa688f23 100644 --- a/extension/src/plots/vega/util.ts +++ b/extension/src/plots/vega/util.ts @@ -20,6 +20,7 @@ import { } from 'vega-lite/build/src/spec/repeat' import { TopLevelUnitSpec } from 'vega-lite/build/src/spec/unit' import { ColorScale, PlotSize, Revision } from '../webview/contract' +import { ShapeEncoding, StrokeDashEncoding } from '../multiSource/constants' const COMMIT_FIELD = 'rev' @@ -98,27 +99,67 @@ export const getColorScale = ( return acc.domain.length > 0 ? acc : undefined } -type EncodingUpdate = { - encoding: { - color: { - legend: { - disable: boolean - } - scale: ColorScale +export type Encoding = { + strokeDash?: StrokeDashEncoding & { + legend: { + disable: boolean + } + } + shape?: ShapeEncoding & { + legend: { + disable: boolean + } + } + detail?: { + field: string + } + color?: { + legend: { + disable: boolean } + scale: ColorScale } } -export const getSpecEncodingUpdate = ( - colorScale: ColorScale -): EncodingUpdate => ({ - encoding: { - color: { +type EncodingUpdate = { + encoding: Encoding +} + +export const getSpecEncodingUpdate = ({ + color, + shape, + strokeDash +}: { + color?: ColorScale + shape?: ShapeEncoding + strokeDash?: StrokeDashEncoding +}): EncodingUpdate => { + const encoding: Encoding = {} + if (color) { + encoding.color = { legend: { disable: true }, - scale: colorScale + scale: color } } -}) + + if (strokeDash) { + encoding.strokeDash = { + ...strokeDash, + legend: { disable: true } + } + } + if (shape) { + encoding.shape = { + ...shape, + legend: { disable: true } + } + encoding.detail = { field: shape.field } + } + + return { + encoding + } +} const mergeUpdate = (spec: TopLevelSpec, update: EncodingUpdate) => { let newSpec = cloneDeep(spec) as any @@ -230,15 +271,29 @@ export const truncateTitles = ( export const extendVegaSpec = ( spec: TopLevelSpec, size: PlotSize, - colorScale?: ColorScale + encoding?: { + color?: ColorScale + strokeDash?: StrokeDashEncoding + shape?: ShapeEncoding + } ) => { const updatedSpec = truncateTitles(spec, size) as unknown as TopLevelSpec - if (isMultiViewByCommitPlot(spec) || !colorScale) { + if (isMultiViewByCommitPlot(spec) || !encoding) { return updatedSpec } - const update = getSpecEncodingUpdate(colorScale) + const update = getSpecEncodingUpdate(encoding) return mergeUpdate(updatedSpec, update) } + +export const reverseOfLegendSuppressionUpdate = () => ({ + spec: { + encoding: { + color: { legend: { disable: false } }, + shape: { legend: { disable: false } }, + strokeDash: { legend: { disable: false } } + } + } +}) diff --git a/extension/src/test/fixtures/plotsDiff/index.ts b/extension/src/test/fixtures/plotsDiff/index.ts index 8d610e9d01..005f77129b 100644 --- a/extension/src/test/fixtures/plotsDiff/index.ts +++ b/extension/src/test/fixtures/plotsDiff/index.ts @@ -480,8 +480,10 @@ const extendedSpecs = (plotsOutput: TemplatePlots): TemplatePlotSection[] => { } as TopLevelSpec, PlotSize.REGULAR, { - domain: expectedRevisions, - range: copyOriginalColors().slice(0, 5) + color: { + domain: expectedRevisions, + range: copyOriginalColors().slice(0, 5) + } } ) as VisualizationSpec, id: path, diff --git a/turbo.json b/turbo.json index 0de227f545..c326b087bb 100644 --- a/turbo.json +++ b/turbo.json @@ -48,6 +48,7 @@ }, "globalDependencies": [ "extension/src/test/fixtures/**", - "extension/src/experiments/columns/constants.ts" + "extension/src/experiments/columns/constants.ts", + "extension/src/plots/vega/util" ] } diff --git a/webview/src/plots/components/ZoomedInPlot.tsx b/webview/src/plots/components/ZoomedInPlot.tsx index 128b8117eb..8ad92c5c9e 100644 --- a/webview/src/plots/components/ZoomedInPlot.tsx +++ b/webview/src/plots/components/ZoomedInPlot.tsx @@ -3,6 +3,7 @@ import VegaLite, { VegaLiteProps } from 'react-vega/lib/VegaLite' import { Config } from 'vega-lite' import merge from 'lodash.merge' import cloneDeep from 'lodash.clonedeep' +import { reverseOfLegendSuppressionUpdate } from 'dvc/src/plots/vega/util' import styles from './styles.module.scss' import { getThemeValue, ThemeProperty } from '../../util/styles' @@ -25,12 +26,7 @@ export const ZoomedInPlot: React.FC = ({ return (