diff --git a/tensorboard/webapp/hparams/_redux/BUILD b/tensorboard/webapp/hparams/_redux/BUILD index cc17a3751c..b7f7cf23c4 100644 --- a/tensorboard/webapp/hparams/_redux/BUILD +++ b/tensorboard/webapp/hparams/_redux/BUILD @@ -74,6 +74,7 @@ tf_ts_library( ":hparams_actions", ":types", ":utils", + "//tensorboard/webapp:app_state", "//tensorboard/webapp/hparams:types", "//tensorboard/webapp/runs/actions", "@npm//@ngrx/store", diff --git a/tensorboard/webapp/hparams/_redux/hparams_selectors.ts b/tensorboard/webapp/hparams/_redux/hparams_selectors.ts index 56244ab76b..fb3410fc5f 100644 --- a/tensorboard/webapp/hparams/_redux/hparams_selectors.ts +++ b/tensorboard/webapp/hparams/_redux/hparams_selectors.ts @@ -25,81 +25,126 @@ import { const getHparamsState = createFeatureSelector(HPARAMS_FEATURE_KEY); -const getHparamsDefaultFiltersForExperiments = createSelector( - getHparamsState, - ( - state: HparamsState, - experimentIds: string[] - ): Map => { - const defaultFilterMaps: Array< - Map - > = []; - - for (const experimentId of experimentIds) { - if (!state.specs[experimentId]) { - continue; - } - - defaultFilterMaps.push(state.specs[experimentId].hparam.defaultFilters); +function getHparamsDefaultFiltersForExperimentsResultFunction( + state: HparamsState, + experimentIds: string[] +): Map { + const defaultFilterMaps: Array> = + []; + + for (const experimentId of experimentIds) { + if (!state.specs[experimentId]) { + continue; } - return combineDefaultHparamFilters(defaultFilterMaps); + defaultFilterMaps.push(state.specs[experimentId].hparam.defaultFilters); } -); -export const getHparamFilterMap = createSelector( - getHparamsDefaultFiltersForExperiments, + return combineDefaultHparamFilters(defaultFilterMaps); +} + +const getHparamsDefaultFiltersForExperiments = createSelector( getHparamsState, - ( - combinedDefaultfilterMap, - hparamState, - experimentIds: string[] - ): Map => { - const id = getIdFromExperimentIds(experimentIds); - const otherFilter = hparamState.filters[id]; - - return new Map([ - ...combinedDefaultfilterMap, - ...(otherFilter?.hparams ?? []), - ]); - } + getHparamsDefaultFiltersForExperimentsResultFunction ); -const getMetricsDefaultFiltersForExperiments = createSelector( - getHparamsState, - ( - state: HparamsState, - experimentIds: string[] - ): Map => { - const defaultFilterMaps: Array> = []; +function getHparamsDefaultFiltersForExperimentsFromExperimentIds( + experimentIds: string[] +) { + return createSelector(getHparamsState, (state) => + getHparamsDefaultFiltersForExperimentsResultFunction(state, experimentIds) + ); +} + +function getHparamFilterMapResultFunction( + hparamState: HparamsState, + combinedDefaultfilterMap: Map, + experimentIds: string[] +): Map { + const id = getIdFromExperimentIds(experimentIds); + const otherFilter = hparamState.filters[id]; + + return new Map([ + ...combinedDefaultfilterMap, + ...(otherFilter?.hparams ?? []), + ]); +} - for (const experimentId of experimentIds) { - if (!state.specs[experimentId]) { - continue; - } +export const getHparamFilterMap = createSelector( + getHparamsState, + getHparamsDefaultFiltersForExperiments, + getHparamFilterMapResultFunction +); - defaultFilterMaps.push(state.specs[experimentId].metric.defaultFilters); +export function getHparamFilterMapFromExperimentIds(experimentIds: string[]) { + return createSelector( + getHparamsState, + getHparamsDefaultFiltersForExperimentsFromExperimentIds(experimentIds), + (hparamState, combinedDefaultFilterMap) => + getHparamFilterMapResultFunction( + hparamState, + combinedDefaultFilterMap, + experimentIds + ) + ); +} + +function getMetricsDefaultFiltersForExperimentsResultFunction( + state: HparamsState, + experimentIds: string[] +): Map { + const defaultFilterMaps: Array> = []; + + for (const experimentId of experimentIds) { + if (!state.specs[experimentId]) { + continue; } - return combineDefaultMetricFilters(defaultFilterMaps); + defaultFilterMaps.push(state.specs[experimentId].metric.defaultFilters); } + + return combineDefaultMetricFilters(defaultFilterMaps); +} + +const getMetricsDefaultFiltersForExperiments = createSelector( + getHparamsState, + getMetricsDefaultFiltersForExperimentsResultFunction ); +function getMetricsDefaultFiltersForExperimentsFromExperimentIds( + experimentIds: string[] +) { + return createSelector(getHparamsState, (state) => + getMetricsDefaultFiltersForExperimentsResultFunction(state, experimentIds) + ); +} + +function getMetricFilterMapResultFunction( + hparamState: HparamsState, + defaultFilterMap: Map, + experimentIds: string[] +): Map { + const id = getIdFromExperimentIds(experimentIds); + const otherFilter = hparamState.filters[id]; + + return new Map([...defaultFilterMap, ...(otherFilter?.metrics ?? [])]); +} + export const getMetricFilterMap = createSelector( - getMetricsDefaultFiltersForExperiments, getHparamsState, - ( - defaultfilterMap, - hparamState, - experimentIds: string[] - ): Map => { - const id = getIdFromExperimentIds(experimentIds); - const otherFilter = hparamState.filters[id]; - - return new Map([...defaultfilterMap, ...(otherFilter?.metrics ?? [])]); - } + getMetricsDefaultFiltersForExperiments, + getMetricFilterMapResultFunction ); +export function getMetricFilterMapFromExperimentIds(experimentIds: string[]) { + return createSelector( + getHparamsState, + getMetricsDefaultFiltersForExperimentsFromExperimentIds(experimentIds), + (state, defaultFilterMap) => + getMetricFilterMapResultFunction(state, defaultFilterMap, experimentIds) + ); +} + /** * Returns Observable that emits hparams and metrics specs of experiments. */ diff --git a/tensorboard/webapp/hparams/_redux/hparams_selectors_test.ts b/tensorboard/webapp/hparams/_redux/hparams_selectors_test.ts index 9f2b8d39e2..ce620c394b 100644 --- a/tensorboard/webapp/hparams/_redux/hparams_selectors_test.ts +++ b/tensorboard/webapp/hparams/_redux/hparams_selectors_test.ts @@ -117,6 +117,98 @@ describe('hparams/_redux/hparams_selectors_test', () => { }); }); + describe('#getHparamFilterMapFromExperimentIds()', () => { + it('returns default hparam filter map', () => { + const state = buildStateFromHparamsState( + buildHparamsState( + buildSpecs('foo', { + hparam: { + specs: [buildHparamSpec({name: 'optimizer'})], + defaultFilters: new Map([ + [ + 'optimizer', + buildDiscreteFilter({ + filterValues: ['a', 'b', 'c'], + }), + ], + ]), + }, + }) + ) + ); + + expect( + selectors.getHparamFilterMapFromExperimentIds(['foo'])(state) + ).toEqual( + new Map([ + ['optimizer', buildDiscreteFilter({filterValues: ['a', 'b', 'c']})], + ]) + ); + }); + + it('returns custom hparam filter map', () => { + const state = buildStateFromHparamsState( + buildHparamsState( + buildSpecs('foo', { + hparam: { + specs: [buildHparamSpec({name: 'optimizer'})], + defaultFilters: new Map([ + [ + 'optimizer', + buildDiscreteFilter({ + filterValues: ['a', 'b', 'c'], + }), + ], + ]), + }, + }), + buildFilterState(['foo'], { + hparams: new Map([ + [ + 'optimizer', + buildDiscreteFilter({ + filterValues: ['d', 'e', 'f'], + }), + ], + ]), + }) + ) + ); + + expect( + selectors.getHparamFilterMapFromExperimentIds(['foo'])(state) + ).toEqual( + new Map([ + [ + 'optimizer', + buildDiscreteFilter({ + filterValues: ['d', 'e', 'f'], + }), + ], + ]) + ); + }); + + it('returns empty map for an unknown exp', () => { + const state = buildStateFromHparamsState( + buildHparamsState( + buildSpecs('foo', { + hparam: { + specs: [buildHparamSpec({name: 'optimizer'})], + defaultFilters: new Map([ + ['optimizer', buildDiscreteFilter({filterValues: ['a']})], + ]), + }, + }) + ) + ); + + expect( + selectors.getHparamFilterMapFromExperimentIds(['bar'])(state) + ).toEqual(new Map()); + }); + }); + describe('#getMetricFilterMap', () => { beforeEach(() => { // Clear the memoization. @@ -232,4 +324,121 @@ describe('hparams/_redux/hparams_selectors_test', () => { expect(selectors.getMetricFilterMap(state, ['bar'])).toEqual(new Map()); }); }); + + describe('#getMetricFilterMapFromExperimentIds', () => { + it('returns default metric filter map', () => { + const state = buildStateFromHparamsState( + buildHparamsState( + buildSpecs('foo', { + metric: { + specs: [buildMetricSpec({tag: 'acc'})], + defaultFilters: new Map([ + [ + 'acc', + buildIntervalFilter({ + filterLowerValue: 0, + filterUpperValue: 1, + }), + ], + ]), + }, + }) + ) + ); + + expect( + selectors.getMetricFilterMapFromExperimentIds(['foo'])(state) + ).toEqual( + new Map([ + [ + 'acc', + buildIntervalFilter({ + filterLowerValue: 0, + filterUpperValue: 1, + }), + ], + ]) + ); + }); + + it('returns custom metric filter map', () => { + const state = buildStateFromHparamsState( + buildHparamsState( + buildSpecs('foo', { + metric: { + specs: [buildMetricSpec({tag: 'acc'})], + defaultFilters: new Map([ + [ + 'acc', + buildIntervalFilter({ + filterLowerValue: 0, + filterUpperValue: 1, + }), + ], + ]), + }, + }), + buildFilterState(['foo'], { + metrics: new Map([ + [ + 'acc', + buildIntervalFilter({ + filterLowerValue: 0, + filterUpperValue: 0.1, + }), + ], + ]), + }) + ) + ); + expect( + selectors.getMetricFilterMapFromExperimentIds(['foo'])(state) + ).toEqual( + new Map([ + [ + 'acc', + buildIntervalFilter({ + filterLowerValue: 0, + filterUpperValue: 0.1, + }), + ], + ]) + ); + }); + + it('returns empty map for an unknown exp', () => { + const state = buildStateFromHparamsState( + buildHparamsState( + buildSpecs('foo', { + metric: { + specs: [buildMetricSpec({tag: 'acc'})], + defaultFilters: new Map([ + [ + 'acc', + buildIntervalFilter({ + filterLowerValue: 0, + filterUpperValue: 1, + }), + ], + ]), + }, + }), + buildFilterState(['foo'], { + metrics: new Map([ + [ + 'acc', + buildIntervalFilter({ + filterLowerValue: 0, + filterUpperValue: 0.1, + }), + ], + ]), + }) + ) + ); + expect( + selectors.getMetricFilterMapFromExperimentIds(['bar'])(state) + ).toEqual(new Map()); + }); + }); });