diff --git a/tensorboard/webapp/hparams/_redux/BUILD b/tensorboard/webapp/hparams/_redux/BUILD index 15a565d974..5073a2c3f0 100644 --- a/tensorboard/webapp/hparams/_redux/BUILD +++ b/tensorboard/webapp/hparams/_redux/BUILD @@ -65,6 +65,7 @@ tf_ts_library( ":types", ":utils", "//tensorboard/webapp/hparams:_types", + "//tensorboard/webapp/persistent_settings", "//tensorboard/webapp/runs/actions", "//tensorboard/webapp/widgets/data_table:types", "//tensorboard/webapp/widgets/data_table:utils", @@ -168,6 +169,7 @@ tf_ts_library( "//tensorboard/webapp/app_routing/actions", "//tensorboard/webapp/core/actions", "//tensorboard/webapp/hparams:types", + "//tensorboard/webapp/persistent_settings", "//tensorboard/webapp/runs/actions", "//tensorboard/webapp/runs/data_source:testing", "//tensorboard/webapp/runs/store:testing", diff --git a/tensorboard/webapp/hparams/_redux/hparams_reducers.ts b/tensorboard/webapp/hparams/_redux/hparams_reducers.ts index 5ff2897c77..eaca5cb3f0 100644 --- a/tensorboard/webapp/hparams/_redux/hparams_reducers.ts +++ b/tensorboard/webapp/hparams/_redux/hparams_reducers.ts @@ -13,8 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ import {Action, ActionReducer, createReducer, on} from '@ngrx/store'; -import {Side} from '../../widgets/data_table/types'; import {DataTableUtils} from '../../widgets/data_table/utils'; +import {persistentSettingsLoaded} from '../../persistent_settings'; +import {Side} from '../../widgets/data_table/types'; import * as actions from './hparams_actions'; import {HparamsState} from './types'; @@ -33,6 +34,16 @@ const initialState: HparamsState = { const reducer: ActionReducer = createReducer( initialState, + on(persistentSettingsLoaded, (state, {partialSettings}) => { + const {dashboardDisplayedHparamColumns: storedColumns} = partialSettings; + if (storedColumns) { + return { + ...state, + dashboardDisplayedHparamColumns: storedColumns, + }; + } + return state; + }), on(actions.hparamsFetchSessionGroupsSucceeded, (state, action) => { const nextDashboardSpecs = action.hparamsAndMetricsSpecs; const nextDashboardSessionGroups = action.sessionGroups; diff --git a/tensorboard/webapp/hparams/_redux/hparams_reducers_test.ts b/tensorboard/webapp/hparams/_redux/hparams_reducers_test.ts index 895b8f6170..6ed102ac59 100644 --- a/tensorboard/webapp/hparams/_redux/hparams_reducers_test.ts +++ b/tensorboard/webapp/hparams/_redux/hparams_reducers_test.ts @@ -19,8 +19,81 @@ import {reducers} from './hparams_reducers'; import {buildHparamSpec, buildHparamsState, buildMetricSpec} from './testing'; import {ColumnHeaderType, Side} from '../../widgets/data_table/types'; import {DataTableUtils} from '../../widgets/data_table/utils'; +import {persistentSettingsLoaded} from '../../persistent_settings'; describe('hparams/_redux/hparams_reducers_test', () => { + describe('#persistentSettingsLoaded', () => { + it('loads dashboardDisplayedHparamColumns from the persistent settings storage', () => { + const state = buildHparamsState({ + dashboardDisplayedHparamColumns: [], + }); + const state2 = reducers( + state, + persistentSettingsLoaded({ + partialSettings: { + dashboardDisplayedHparamColumns: [ + { + type: ColumnHeaderType.HPARAM, + name: 'conv_layers', + displayName: 'Conv Layers', + enabled: true, + }, + { + type: ColumnHeaderType.HPARAM, + name: 'conv_kernel_size', + displayName: 'Conv Kernel Size', + enabled: true, + }, + ], + }, + }) + ); + + expect(state2.dashboardDisplayedHparamColumns).toEqual([ + { + type: ColumnHeaderType.HPARAM, + name: 'conv_layers', + displayName: 'Conv Layers', + enabled: true, + }, + { + type: ColumnHeaderType.HPARAM, + name: 'conv_kernel_size', + displayName: 'Conv Kernel Size', + enabled: true, + }, + ]); + }); + + it('does nothing if persistent settings does not contain dashboardDisplayedHparamColumns', () => { + const state = buildHparamsState({ + dashboardDisplayedHparamColumns: [ + { + type: ColumnHeaderType.HPARAM, + name: 'conv_layers', + displayName: 'Conv Layers', + enabled: true, + }, + ], + }); + const state2 = reducers( + state, + persistentSettingsLoaded({ + partialSettings: {}, + }) + ); + + expect(state2.dashboardDisplayedHparamColumns).toEqual([ + { + type: ColumnHeaderType.HPARAM, + name: 'conv_layers', + displayName: 'Conv Layers', + enabled: true, + }, + ]); + }); + }); + describe('hparamsFetchSessionGroupsSucceeded', () => { it('saves action.hparamsAndMetricsSpecs as dashboardSpecs', () => { const state = buildHparamsState({ diff --git a/tensorboard/webapp/runs/store/BUILD b/tensorboard/webapp/runs/store/BUILD index c0f9fa24bd..d8715bc8e8 100644 --- a/tensorboard/webapp/runs/store/BUILD +++ b/tensorboard/webapp/runs/store/BUILD @@ -99,6 +99,7 @@ tf_ts_library( ":testing", ":types", ":utils", + "//tensorboard/webapp:app_state", "//tensorboard/webapp/app_routing:testing", "//tensorboard/webapp/app_routing:types", "//tensorboard/webapp/app_routing/actions", diff --git a/tensorboard/webapp/runs/store/runs_selectors.ts b/tensorboard/webapp/runs/store/runs_selectors.ts index 5e5d091636..87ed5f73d9 100644 --- a/tensorboard/webapp/runs/store/runs_selectors.ts +++ b/tensorboard/webapp/runs/store/runs_selectors.ts @@ -330,6 +330,16 @@ export const getRunsTableSortingInfo = createSelector( export const getGroupedRunsTableHeaders = createSelector( getRunsTableHeaders, getDashboardDisplayedHparamColumns, - (runsTableHeaders, hparamColumns) => - DataTableUtils.groupColumns([...runsTableHeaders, ...hparamColumns]) + (runsTableHeaders, hparamColumns) => { + // Override hparam options to match runs table requirements. + const columns = [...runsTableHeaders, ...hparamColumns].map((column) => { + const newColumn = {...column}; + if (column.type === 'HPARAM') { + newColumn.removable = true; + newColumn.hidable = false; + } + return newColumn; + }); + return DataTableUtils.groupColumns(columns); + } ); diff --git a/tensorboard/webapp/runs/store/runs_selectors_test.ts b/tensorboard/webapp/runs/store/runs_selectors_test.ts index 3d31b4e72f..c1e083a69c 100644 --- a/tensorboard/webapp/runs/store/runs_selectors_test.ts +++ b/tensorboard/webapp/runs/store/runs_selectors_test.ts @@ -24,6 +24,7 @@ import { buildHparamSpec, } from '../../hparams/testing'; import {buildMockState} from '../../testing/utils'; +import {State} from '../../app_state'; import {DataLoadState} from '../../types/data'; import {ColumnHeaderType, SortingOrder} from '../../widgets/data_table/types'; import {GroupByKey} from '../types'; @@ -1029,8 +1030,10 @@ describe('runs_selectors', () => { }); describe('#getGroupedRunsTableHeaders', () => { - it('returns runs table headers grouped with other headers', () => { - const state = buildMockState({ + let state: State; + + beforeEach(() => { + state = buildMockState({ runs: buildRunsState( {}, { @@ -1081,38 +1084,70 @@ describe('runs_selectors', () => { }) ), }); + }); + it('returns runs table headers grouped with other headers', () => { expect(selectors.getGroupedRunsTableHeaders(state)).toEqual([ - { + jasmine.objectContaining({ type: ColumnHeaderType.RUN, name: 'run', displayName: 'Run', enabled: true, - }, - { + }), + jasmine.objectContaining({ type: ColumnHeaderType.CUSTOM, name: 'experimentAlias', displayName: 'Experiment Alias', enabled: true, - }, - { + }), + jasmine.objectContaining({ type: ColumnHeaderType.HPARAM, name: 'conv_layers', displayName: 'Conv Layers', enabled: true, - }, - { + }), + jasmine.objectContaining({ type: ColumnHeaderType.HPARAM, name: 'conv_kernel_size', displayName: 'Conv Kernel Size', enabled: true, - }, - { + }), + jasmine.objectContaining({ type: ColumnHeaderType.COLOR, name: 'color', displayName: 'Color', enabled: true, - }, + }), + ]); + }); + + it('sets the hparam column context options for the runs table', () => { + expect(selectors.getGroupedRunsTableHeaders(state)).toEqual([ + jasmine.objectContaining({ + type: ColumnHeaderType.RUN, + }), + jasmine.objectContaining({ + type: ColumnHeaderType.CUSTOM, + }), + jasmine.objectContaining({ + type: ColumnHeaderType.HPARAM, + name: 'conv_layers', + displayName: 'Conv Layers', + enabled: true, + removable: true, + hidable: false, + }), + jasmine.objectContaining({ + type: ColumnHeaderType.HPARAM, + name: 'conv_kernel_size', + displayName: 'Conv Kernel Size', + enabled: true, + removable: true, + hidable: false, + }), + jasmine.objectContaining({ + type: ColumnHeaderType.COLOR, + }), ]); }); }); diff --git a/tensorboard/webapp/runs/views/runs_table/runs_data_table.ng.html b/tensorboard/webapp/runs/views/runs_table/runs_data_table.ng.html index 31e6552043..d91853e2c0 100644 --- a/tensorboard/webapp/runs/views/runs_table/runs_data_table.ng.html +++ b/tensorboard/webapp/runs/views/runs_table/runs_data_table.ng.html @@ -28,7 +28,6 @@ (); - @Output() addColumn = new EventEmitter<{ - header: ColumnHeader; - index?: number | undefined; - }>(); + @Output() addColumn = new EventEmitter(); @Output() removeColumn = new EventEmitter(); @Output() onSelectionDblClick = new EventEmitter(); @Output() addFilter = new EventEmitter(); diff --git a/tensorboard/webapp/runs/views/runs_table/runs_data_table_test.ts b/tensorboard/webapp/runs/views/runs_table/runs_data_table_test.ts index 0de6406a5d..a4d3bc30e2 100644 --- a/tensorboard/webapp/runs/views/runs_table/runs_data_table_test.ts +++ b/tensorboard/webapp/runs/views/runs_table/runs_data_table_test.ts @@ -140,21 +140,22 @@ describe('runs_data_table', () => { ).toBeTruthy(); }); - it('projects enabled headers plus color and selected column', () => { + it('projects headers plus color and selected column', () => { const fixture = createComponent({}); const dataTable = fixture.debugElement.query( By.directive(DataTableComponent) ); const headers = dataTable.queryAll(By.directive(HeaderCellComponent)); - expect(headers.length).toBe(4); + expect(headers.length).toBe(5); expect(headers[0].componentInstance.header.name).toEqual('selected'); expect(headers[1].componentInstance.header.name).toEqual('run'); - expect(headers[2].componentInstance.header.name).toEqual('other_header'); - expect(headers[3].componentInstance.header.name).toEqual('color'); + expect(headers[2].componentInstance.header.name).toEqual('disabled_header'); + expect(headers[3].componentInstance.header.name).toEqual('other_header'); + expect(headers[4].componentInstance.header.name).toEqual('color'); }); - it('projects content for each enabled header, selected, and color column', () => { + it('projects content for each header, selected, and color column', () => { const fixture = createComponent({ data: [{id: 'runid', run: 'run name', color: 'red', other_header: 'foo'}], }); @@ -163,11 +164,12 @@ describe('runs_data_table', () => { ); const cells = dataTable.queryAll(By.directive(ContentCellComponent)); - expect(cells.length).toBe(4); + expect(cells.length).toBe(5); expect(cells[0].componentInstance.header.name).toEqual('selected'); expect(cells[1].componentInstance.header.name).toEqual('run'); - expect(cells[2].componentInstance.header.name).toEqual('other_header'); - expect(cells[3].componentInstance.header.name).toEqual('color'); + expect(cells[2].componentInstance.header.name).toEqual('disabled_header'); + expect(cells[3].componentInstance.header.name).toEqual('other_header'); + expect(cells[4].componentInstance.header.name).toEqual('color'); }); describe('color column', () => { diff --git a/tensorboard/webapp/runs/views/runs_table/runs_table_container.ts b/tensorboard/webapp/runs/views/runs_table/runs_table_container.ts index 6bf276aef5..a9fbb8aacf 100644 --- a/tensorboard/webapp/runs/views/runs_table/runs_table_container.ts +++ b/tensorboard/webapp/runs/views/runs_table/runs_table_container.ts @@ -22,7 +22,6 @@ import { import {createSelector, Store} from '@ngrx/store'; import {combineLatest, Observable, of, Subject} from 'rxjs'; import { - combineLatestWith, distinctUntilChanged, filter, map, @@ -44,13 +43,15 @@ import { getRunSelectorRegexFilter, getRunsLoadState, getRunsTableFullScreen, - getRunsTableHeaders, getRunsTableSortingInfo, + getGroupedRunsTableHeaders, } from '../../../selectors'; import {DataLoadState, LoadState} from '../../../types/data'; import { + AddColumnEvent, ColumnHeader, FilterAddedEvent, + ReorderColumnEvent, SortingInfo, TableData, } from '../../../widgets/data_table/types'; @@ -59,9 +60,6 @@ import { runPageSelectionToggled, runSelectionToggled, runSelectorRegexFilterChanged, - runsTableHeaderAdded, - runsTableHeaderOrderChanged, - runsTableHeaderRemoved, runsTableSortingInfoChanged, singleRunSelected, } from '../../actions'; @@ -70,7 +68,7 @@ import {RunsTableColumn, RunTableItem} from './types'; import { getCurrentColumnFilters, getFilteredRenderableRuns, - getPotentialHparamColumns, + getSelectableColumns, } from '../../../metrics/views/main_view/common_selectors'; import {runsTableFullScreenToggled} from '../../../core/actions'; import {sortTableDataItems} from './sorting_utils'; @@ -143,18 +141,9 @@ export class RunsTableContainer implements OnInit, OnDestroy { @Input() showHparamsAndMetrics = false; regexFilter$ = this.store.select(getRunSelectorRegexFilter); - runsColumns$ = this.store.select(getRunsTableHeaders); + runsColumns$ = this.store.select(getGroupedRunsTableHeaders); runsTableFullScreen$ = this.store.select(getRunsTableFullScreen); - - selectableColumns$ = this.store.select(getPotentialHparamColumns).pipe( - combineLatestWith(this.runsColumns$), - map(([potentialColumns, currentColumns]) => { - const currentColumnNames = new Set(currentColumns.map(({name}) => name)); - return potentialColumns.filter((columnHeader) => { - return !currentColumnNames.has(columnHeader.name); - }); - }) - ); + selectableColumns$ = this.store.select(getSelectableColumns); columnFilters$ = this.store.select(getCurrentColumnFilters); @@ -332,19 +321,26 @@ export class RunsTableContainer implements OnInit, OnDestroy { this.store.dispatch(runsTableFullScreenToggled()); } - addColumn({header, index}: {header: ColumnHeader; index: number}) { - header.enabled = true; + addColumn({column, nextTo, side}: AddColumnEvent) { this.store.dispatch( - runsTableHeaderAdded({header: {...header, enabled: true}, index}) + hparamsActions.dashboardHparamColumnAdded({ + column, + nextTo, + side, + }) ); } removeColumn(header: ColumnHeader) { - this.store.dispatch(runsTableHeaderRemoved({header})); + this.store.dispatch( + hparamsActions.dashboardHparamColumnRemoved({column: header}) + ); } - orderColumns(newHeaderOrder: ColumnHeader[]) { - this.store.dispatch(runsTableHeaderOrderChanged({newHeaderOrder})); + orderColumns(event: ReorderColumnEvent) { + this.store.dispatch( + hparamsActions.dashboardHparamColumnOrderChanged(event) + ); } addHparamFilter(event: FilterAddedEvent) { diff --git a/tensorboard/webapp/widgets/data_table/data_table_component.ts b/tensorboard/webapp/widgets/data_table/data_table_component.ts index dea6d2080a..09224b3c3b 100644 --- a/tensorboard/webapp/widgets/data_table/data_table_component.ts +++ b/tensorboard/webapp/widgets/data_table/data_table_component.ts @@ -36,6 +36,7 @@ import { SortingOrder, ReorderColumnEvent, Side, + AddColumnEvent, } from './types'; import {HeaderCellComponent} from './header_cell_component'; import {Subscription} from 'rxjs'; @@ -78,10 +79,7 @@ export class DataTableComponent implements OnDestroy, AfterContentInit { @Output() sortDataBy = new EventEmitter(); @Output() orderColumns = new EventEmitter(); @Output() removeColumn = new EventEmitter(); - @Output() addColumn = new EventEmitter<{ - header: ColumnHeader; - index?: number | undefined; - }>(); + @Output() addColumn = new EventEmitter(); @Output() addFilter = new EventEmitter(); @ViewChild('columnSelectorModal', {static: false}) @@ -347,7 +345,11 @@ export class DataTableComponent implements OnDestroy, AfterContentInit { } onColumnAdded(header: ColumnHeader) { - this.addColumn.emit({header, index: this.getInsertIndex()}); + this.addColumn.emit({ + column: header, + nextTo: this.contextMenuHeader, + side: this.insertColumnTo, + }); } openFilterMenu(event: MouseEvent, header: ColumnHeader) {