diff --git a/tensorboard/webapp/hparams/_redux/BUILD b/tensorboard/webapp/hparams/_redux/BUILD index 2cb12086fd..2b3608bd5b 100644 --- a/tensorboard/webapp/hparams/_redux/BUILD +++ b/tensorboard/webapp/hparams/_redux/BUILD @@ -50,6 +50,7 @@ tf_ts_library( deps = [ ":types", "//tensorboard/webapp/hparams:types", + "//tensorboard/webapp/widgets/data_table:types", "@npm//@ngrx/store", ], ) @@ -65,6 +66,7 @@ tf_ts_library( ":utils", "//tensorboard/webapp/hparams:_types", "//tensorboard/webapp/runs/actions", + "//tensorboard/webapp/widgets/data_table:types", "@npm//@ngrx/store", ], ) @@ -171,6 +173,7 @@ tf_ts_library( "//tensorboard/webapp/testing:utils", "//tensorboard/webapp/util:types", "//tensorboard/webapp/webapp_data_source:http_client_testing", + "//tensorboard/webapp/widgets/data_table:types", "@npm//@ngrx/effects", "@npm//@ngrx/store", "@npm//@types/jasmine", diff --git a/tensorboard/webapp/hparams/_redux/hparams_actions.ts b/tensorboard/webapp/hparams/_redux/hparams_actions.ts index 60c68cbbcd..95a83e7aee 100644 --- a/tensorboard/webapp/hparams/_redux/hparams_actions.ts +++ b/tensorboard/webapp/hparams/_redux/hparams_actions.ts @@ -17,7 +17,11 @@ limitations under the License. */ import {createAction, props} from '@ngrx/store'; -import {HparamAndMetricSpec, SessionGroup} from '../types'; +import { + AddColumnEvent, + ReorderColumnEvent, +} from '../../widgets/data_table/types'; +import {HparamAndMetricSpec, SessionGroup, ColumnHeader} from '../types'; import {HparamFilter, MetricFilter} from './types'; export const hparamsFetchSessionGroupsSucceeded = createAction( @@ -47,3 +51,23 @@ export const dashboardMetricFilterRemoved = createAction( '[Hparams] Dashboard Metric Filter Removed', props<{name: string}>() ); + +export const dashboardHparamColumnAdded = createAction( + '[Hparams] Dashboard Hparam Column Added', + props() +); + +export const dashboardHparamColumnRemoved = createAction( + '[Hparams] Dashboard Hparam Column Removed', + props<{column: ColumnHeader}>() +); + +export const dashboardHparamColumnToggled = createAction( + '[Hparams] Dashboard Hparam Column Toggled', + props<{column: ColumnHeader}>() +); + +export const dashboardHparamColumnOrderChanged = createAction( + '[Hparams] Dashboard Hparam Column Order Changed', + props() +); diff --git a/tensorboard/webapp/hparams/_redux/hparams_reducers.ts b/tensorboard/webapp/hparams/_redux/hparams_reducers.ts index aba10fff32..be0afde330 100644 --- a/tensorboard/webapp/hparams/_redux/hparams_reducers.ts +++ b/tensorboard/webapp/hparams/_redux/hparams_reducers.ts @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ import {Action, ActionReducer, createReducer, on} from '@ngrx/store'; +import {ColumnHeader, Side} from '../../widgets/data_table/types'; import * as actions from './hparams_actions'; import {HparamsState} from './types'; @@ -26,6 +27,7 @@ const initialState: HparamsState = { hparams: new Map(), metrics: new Map(), }, + dashboardDisplayedHparamColumns: [], }; const reducer: ActionReducer = createReducer( @@ -87,7 +89,82 @@ const reducer: ActionReducer = createReducer( metrics: nextMetricFilters, }, }; - }) + }), + on(actions.dashboardHparamColumnAdded, (state, {column, nextTo, side}) => { + const {dashboardDisplayedHparamColumns: oldColumns} = state; + + let destinationIndex = oldColumns.length; // Default to append at end. + if (nextTo !== undefined && side !== undefined) { + const nextToIndex = oldColumns.findIndex( + (col) => col.name === nextTo.name + ); + if (nextToIndex !== -1) { + destinationIndex = side === Side.RIGHT ? nextToIndex + 1 : nextToIndex; + } + } + let newColumn = {...column, enabled: true}; + let newColumns = [...oldColumns]; + newColumns.splice(destinationIndex, 0, newColumn); + + return { + ...state, + dashboardDisplayedHparamColumns: newColumns, + }; + }), + on(actions.dashboardHparamColumnRemoved, (state, {column}) => { + const newColumns = state.dashboardDisplayedHparamColumns.filter( + ({name}) => name !== column.name + ); + + return { + ...state, + dashboardDisplayedHparamColumns: newColumns, + }; + }), + on(actions.dashboardHparamColumnToggled, (state, {column: toggledColumn}) => { + const newColumns = state.dashboardDisplayedHparamColumns.map((column) => { + if (column.name === toggledColumn.name) { + return { + ...column, + enabled: !toggledColumn.enabled, + }; + } else { + return column; + } + }); + + return { + ...state, + dashboardDisplayedHparamColumns: newColumns, + }; + }), + on( + actions.dashboardHparamColumnOrderChanged, + (state, {source, destination, side}) => { + const {dashboardDisplayedHparamColumns: columns} = state; + let sourceIndex = columns.findIndex( + (column: ColumnHeader) => column.name === source.name + ); + let destinationIndex = columns.findIndex( + (column: ColumnHeader) => column.name === destination.name + ); + if (sourceIndex === -1 || sourceIndex === destinationIndex) { + return state; + } + if (destinationIndex === -1) { + destinationIndex = side === Side.LEFT ? 0 : columns.length - 1; + } + + const newColumns = [...columns]; + newColumns.splice(sourceIndex, 1); + newColumns.splice(destinationIndex, 0, source); + + return { + ...state, + dashboardDisplayedHparamColumns: newColumns, + }; + } + ) ); export function reducers(state: HparamsState | undefined, action: Action) { diff --git a/tensorboard/webapp/hparams/_redux/hparams_reducers_test.ts b/tensorboard/webapp/hparams/_redux/hparams_reducers_test.ts index 867a6ffc56..63aa4da507 100644 --- a/tensorboard/webapp/hparams/_redux/hparams_reducers_test.ts +++ b/tensorboard/webapp/hparams/_redux/hparams_reducers_test.ts @@ -17,6 +17,7 @@ import {DomainType, RunStatus} from '../types'; import * as actions from './hparams_actions'; import {reducers} from './hparams_reducers'; import {buildHparamSpec, buildHparamsState, buildMetricSpec} from './testing'; +import {ColumnHeaderType, Side} from '../../widgets/data_table/types'; describe('hparams/_redux/hparams_reducers_test', () => { describe('hparamsFetchSessionGroupsSucceeded', () => { @@ -309,4 +310,424 @@ describe('hparams/_redux/hparams_reducers_test', () => { }); }); }); + + describe('dashboardHparamColumnAdded', () => { + const fakeColumns = [ + { + type: ColumnHeaderType.HPARAM, + name: 'conv_layers', + displayName: 'Conv Layers', + enabled: true, + }, + { + type: ColumnHeaderType.HPARAM, + name: 'conv_kernel_size', + displayName: 'Conv Kernel Size', + enabled: true, + }, + ]; + + it('appends an hparam column to the end', () => { + const state = buildHparamsState({ + dashboardDisplayedHparamColumns: fakeColumns, + }); + const state2 = reducers( + state, + actions.dashboardHparamColumnAdded({ + column: { + type: ColumnHeaderType.HPARAM, + name: 'dense_layers', + displayName: 'Dense Layers', + enabled: true, + }, + }) + ); + + expect(state2.dashboardDisplayedHparamColumns).toEqual([ + ...fakeColumns, + { + type: ColumnHeaderType.HPARAM, + name: 'dense_layers', + displayName: 'Dense Layers', + enabled: true, + }, + ]); + }); + + it('inserts an hparam column to the left of an existing column', () => { + const state = buildHparamsState({ + dashboardDisplayedHparamColumns: fakeColumns, + }); + const state2 = reducers( + state, + actions.dashboardHparamColumnAdded({ + column: { + type: ColumnHeaderType.HPARAM, + name: 'dense_layers', + displayName: 'Dense Layers', + enabled: true, + }, + nextTo: { + type: ColumnHeaderType.HPARAM, + name: 'conv_layers', + displayName: 'Conv Layers', + enabled: true, + }, + side: Side.LEFT, + }) + ); + + expect(state2.dashboardDisplayedHparamColumns).toEqual([ + { + type: ColumnHeaderType.HPARAM, + name: 'dense_layers', + displayName: 'Dense Layers', + enabled: true, + }, + ...fakeColumns, + ]); + }); + + it('inserts an hparam column to the right of an existing column', () => { + const state = buildHparamsState({ + dashboardDisplayedHparamColumns: fakeColumns, + }); + const state2 = reducers( + state, + actions.dashboardHparamColumnAdded({ + column: { + type: ColumnHeaderType.HPARAM, + name: 'dense_layers', + displayName: 'Dense Layers', + enabled: true, + }, + nextTo: { + type: ColumnHeaderType.HPARAM, + name: 'conv_layers', + displayName: 'Conv Layers', + enabled: true, + }, + side: Side.RIGHT, + }) + ); + + expect(state2.dashboardDisplayedHparamColumns).toEqual([ + fakeColumns[0], + { + type: ColumnHeaderType.HPARAM, + name: 'dense_layers', + displayName: 'Dense Layers', + enabled: true, + }, + ...fakeColumns.slice(1), + ]); + }); + + it('appends an hparam column at the end if nextTo is not found', () => { + const state = buildHparamsState({ + dashboardDisplayedHparamColumns: fakeColumns, + }); + const state2 = reducers( + state, + actions.dashboardHparamColumnAdded({ + column: { + type: ColumnHeaderType.HPARAM, + name: 'dense_layers', + displayName: 'Dense Layers', + enabled: true, + }, + nextTo: { + type: ColumnHeaderType.HPARAM, + name: 'nonexistent_layer', + displayName: 'Nonexistent layer', + enabled: true, + }, + side: Side.RIGHT, + }) + ); + + expect(state2.dashboardDisplayedHparamColumns).toEqual([ + ...fakeColumns, + { + type: ColumnHeaderType.HPARAM, + name: 'dense_layers', + displayName: 'Dense Layers', + enabled: true, + }, + ]); + }); + }); + + describe('dashboardHparamColumnRemoved', () => { + it('removes an existing column', () => { + const state = buildHparamsState({ + dashboardDisplayedHparamColumns: [ + { + type: ColumnHeaderType.HPARAM, + name: 'conv_layers', + displayName: 'Conv Layers', + enabled: true, + }, + { + type: ColumnHeaderType.HPARAM, + name: 'conv_kernel_size', + displayName: 'Conv Kernel Size', + enabled: true, + }, + ], + }); + const state2 = reducers( + state, + actions.dashboardHparamColumnRemoved({ + column: { + type: ColumnHeaderType.HPARAM, + name: 'conv_layers', + displayName: 'Conv Layers', + enabled: true, + }, + }) + ); + + expect(state2.dashboardDisplayedHparamColumns).toEqual([ + { + type: ColumnHeaderType.HPARAM, + name: 'conv_kernel_size', + displayName: 'Conv Kernel Size', + enabled: true, + }, + ]); + }); + }); + + describe('dashboardHparamColumnToggled', () => { + it('enables a disabled column', () => { + const state = buildHparamsState({ + dashboardDisplayedHparamColumns: [ + { + type: ColumnHeaderType.HPARAM, + name: 'conv_layers', + displayName: 'Conv Layers', + enabled: false, + }, + ], + }); + const state2 = reducers( + state, + actions.dashboardHparamColumnToggled({ + column: { + type: ColumnHeaderType.HPARAM, + name: 'conv_layers', + displayName: 'Conv Layers', + enabled: false, + }, + }) + ); + + expect(state2.dashboardDisplayedHparamColumns).toEqual([ + { + type: ColumnHeaderType.HPARAM, + name: 'conv_layers', + displayName: 'Conv Layers', + enabled: true, + }, + ]); + }); + + it('disables an enabled column', () => { + const state = buildHparamsState({ + dashboardDisplayedHparamColumns: [ + { + type: ColumnHeaderType.HPARAM, + name: 'conv_layers', + displayName: 'Conv Layers', + enabled: true, + }, + ], + }); + const state2 = reducers( + state, + actions.dashboardHparamColumnToggled({ + column: { + type: ColumnHeaderType.HPARAM, + name: 'conv_layers', + displayName: 'Conv Layers', + enabled: true, + }, + }) + ); + + expect(state2.dashboardDisplayedHparamColumns).toEqual([ + { + type: ColumnHeaderType.HPARAM, + name: 'conv_layers', + displayName: 'Conv Layers', + enabled: false, + }, + ]); + }); + }); + + describe('dashboardHparamColumnOrderChanged', () => { + const fakeColumns = [ + { + type: ColumnHeaderType.HPARAM, + name: 'conv_layers', + displayName: 'Conv Layers', + enabled: true, + }, + { + type: ColumnHeaderType.HPARAM, + name: 'conv_kernel_size', + displayName: 'Conv Kernel Size', + enabled: true, + }, + { + type: ColumnHeaderType.HPARAM, + name: 'dense_layers', + displayName: 'Dense Layers', + enabled: true, + }, + ]; + + it('does nothing if source is not found', () => { + const state = buildHparamsState({ + dashboardDisplayedHparamColumns: fakeColumns, + }); + const state2 = reducers( + state, + actions.dashboardHparamColumnOrderChanged({ + source: { + type: ColumnHeaderType.HPARAM, + name: 'nonexistent_param', + displayName: 'Nonexistent param', + enabled: false, + }, + destination: { + type: ColumnHeaderType.HPARAM, + name: 'conv_kernel_size', + displayName: 'Conv Kernel Size', + enabled: true, + }, + side: Side.LEFT, + }) + ); + + expect(state2.dashboardDisplayedHparamColumns).toEqual(fakeColumns); + }); + + it('does nothing if source equals dest', () => { + const state = buildHparamsState({ + dashboardDisplayedHparamColumns: fakeColumns, + }); + const state2 = reducers( + state, + actions.dashboardHparamColumnOrderChanged({ + source: { + type: ColumnHeaderType.HPARAM, + name: 'conv_kernel_size', + displayName: 'Conv Kernel Size', + enabled: false, + }, + destination: { + type: ColumnHeaderType.HPARAM, + name: 'conv_kernel_size', + displayName: 'Conv Kernel Size', + enabled: true, + }, + side: Side.LEFT, + }) + ); + + expect(state2.dashboardDisplayedHparamColumns).toEqual(fakeColumns); + }); + + it('moves source to front if dest not found and side is left', () => { + const state = buildHparamsState({ + dashboardDisplayedHparamColumns: fakeColumns, + }); + const state2 = reducers( + state, + actions.dashboardHparamColumnOrderChanged({ + source: fakeColumns[1], + destination: { + type: ColumnHeaderType.HPARAM, + name: 'nonexistent param', + displayName: 'Nonexistent param', + enabled: true, + }, + side: Side.LEFT, + }) + ); + + expect(state2.dashboardDisplayedHparamColumns).toEqual([ + fakeColumns[1], + ...fakeColumns.slice(0, 1), + ...fakeColumns.slice(2), + ]); + }); + + it('moves source to back if dest not found and side is right', () => { + const state = buildHparamsState({ + dashboardDisplayedHparamColumns: fakeColumns, + }); + const state2 = reducers( + state, + actions.dashboardHparamColumnOrderChanged({ + source: fakeColumns[0], + destination: { + type: ColumnHeaderType.HPARAM, + name: 'nonexistent param', + displayName: 'Nonexistent param', + enabled: true, + }, + side: Side.RIGHT, + }) + ); + + expect(state2.dashboardDisplayedHparamColumns).toEqual([ + ...fakeColumns.slice(1), + fakeColumns[0], + ]); + }); + + it('inserts source to the left of dest when side is left', () => { + const state = buildHparamsState({ + dashboardDisplayedHparamColumns: fakeColumns, + }); + const state2 = reducers( + state, + actions.dashboardHparamColumnOrderChanged({ + source: fakeColumns[1], + destination: fakeColumns[0], + side: Side.LEFT, + }) + ); + + expect(state2.dashboardDisplayedHparamColumns).toEqual([ + fakeColumns[1], + ...fakeColumns.slice(0, 1), + ...fakeColumns.slice(2), + ]); + }); + + it('inserts source to the right of dest when side is right', () => { + const state = buildHparamsState({ + dashboardDisplayedHparamColumns: fakeColumns, + }); + const state2 = reducers( + state, + actions.dashboardHparamColumnOrderChanged({ + source: fakeColumns[0], + destination: fakeColumns[1], + side: Side.LEFT, + }) + ); + + expect(state2.dashboardDisplayedHparamColumns).toEqual([ + fakeColumns[1], + ...fakeColumns.slice(0, 1), + ...fakeColumns.slice(2), + ]); + }); + }); }); diff --git a/tensorboard/webapp/hparams/_redux/hparams_selectors.ts b/tensorboard/webapp/hparams/_redux/hparams_selectors.ts index 3ffd782b12..90c8b360d0 100644 --- a/tensorboard/webapp/hparams/_redux/hparams_selectors.ts +++ b/tensorboard/webapp/hparams/_redux/hparams_selectors.ts @@ -46,6 +46,11 @@ export const getDashboardDefaultHparamFilters = createSelector( } ); +export const getDashboardDisplayedHparamColumns = createSelector( + getHparamsState, + (state) => state.dashboardDisplayedHparamColumns +); + export const getDashboardHparamFilterMap = createSelector( getHparamsState, (state) => { diff --git a/tensorboard/webapp/hparams/_redux/hparams_selectors_test.ts b/tensorboard/webapp/hparams/_redux/hparams_selectors_test.ts index 7732edab57..cc43c0dfa2 100644 --- a/tensorboard/webapp/hparams/_redux/hparams_selectors_test.ts +++ b/tensorboard/webapp/hparams/_redux/hparams_selectors_test.ts @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +import {ColumnHeaderType} from '../../widgets/data_table/types'; import {DomainType} from '../types'; import * as selectors from './hparams_selectors'; import { @@ -111,4 +112,32 @@ describe('hparams/_redux/hparams_selectors_test', () => { ); }); }); + + describe('#getDashboardDisplayedHparamColumns', () => { + it('returns dashboard displayed hparam columns', () => { + const fakeColumns = [ + { + type: ColumnHeaderType.HPARAM, + name: 'conv_layers', + displayName: 'Conv Layers', + enabled: true, + }, + { + type: ColumnHeaderType.HPARAM, + name: 'conv_kernel_size', + displayName: 'Conv Kernel Size', + enabled: true, + }, + ]; + const state = buildStateFromHparamsState( + buildHparamsState({ + dashboardDisplayedHparamColumns: fakeColumns, + }) + ); + + expect(selectors.getDashboardDisplayedHparamColumns(state)).toEqual( + fakeColumns + ); + }); + }); }); diff --git a/tensorboard/webapp/hparams/_redux/testing.ts b/tensorboard/webapp/hparams/_redux/testing.ts index d0867f040e..fb69c2d2ed 100644 --- a/tensorboard/webapp/hparams/_redux/testing.ts +++ b/tensorboard/webapp/hparams/_redux/testing.ts @@ -41,6 +41,8 @@ export function buildHparamsState( hparams: overrides.dashboardFilters?.hparams ?? new Map(), metrics: overrides.dashboardFilters?.metrics ?? new Map(), }, + dashboardDisplayedHparamColumns: + overrides.dashboardDisplayedHparamColumns ?? [], } as HparamsState; } diff --git a/tensorboard/webapp/hparams/_redux/types.ts b/tensorboard/webapp/hparams/_redux/types.ts index c3735d9450..ddada2e410 100644 --- a/tensorboard/webapp/hparams/_redux/types.ts +++ b/tensorboard/webapp/hparams/_redux/types.ts @@ -15,6 +15,7 @@ limitations under the License. import { DiscreteFilter, IntervalFilter, + ColumnHeader, HparamAndMetricSpec, SessionGroup, } from '../_types'; @@ -34,6 +35,7 @@ export interface HparamsState { hparams: Map; metrics: Map; }; + dashboardDisplayedHparamColumns: ColumnHeader[]; } export interface State { diff --git a/tensorboard/webapp/hparams/_types.ts b/tensorboard/webapp/hparams/_types.ts index e007983d4e..a1f3e91023 100644 --- a/tensorboard/webapp/hparams/_types.ts +++ b/tensorboard/webapp/hparams/_types.ts @@ -17,7 +17,11 @@ import { MetricSpec, } from '../runs/data_source/runs_data_source_types'; -export {DiscreteFilter, IntervalFilter} from '../widgets/data_table/types'; +export { + DiscreteFilter, + IntervalFilter, + ColumnHeader, +} from '../widgets/data_table/types'; export { DatasetType, diff --git a/tensorboard/webapp/widgets/data_table/types.ts b/tensorboard/webapp/widgets/data_table/types.ts index 9cffb03c99..976f2782d4 100644 --- a/tensorboard/webapp/widgets/data_table/types.ts +++ b/tensorboard/webapp/widgets/data_table/types.ts @@ -111,3 +111,20 @@ export enum DataTableMode { SINGLE, RANGE, } + +export enum Side { + RIGHT, + LEFT, +} + +export interface ReorderColumnEvent { + source: ColumnHeader; + destination: ColumnHeader; + side: Side; +} + +export interface AddColumnEvent { + column: ColumnHeader; + nextTo?: ColumnHeader | undefined; + side?: Side | undefined; +}