diff --git a/tensorboard/webapp/metrics/internal_types.ts b/tensorboard/webapp/metrics/internal_types.ts index 43b6fc2ba5..102e42323b 100644 --- a/tensorboard/webapp/metrics/internal_types.ts +++ b/tensorboard/webapp/metrics/internal_types.ts @@ -89,6 +89,7 @@ export interface URLDeserializedState { metrics: { pinnedCards: CardUniqueInfo[]; smoothing: number | null; + tagFilter: string | null; }; } diff --git a/tensorboard/webapp/metrics/store/BUILD b/tensorboard/webapp/metrics/store/BUILD index 326d83341b..b140ff3a02 100644 --- a/tensorboard/webapp/metrics/store/BUILD +++ b/tensorboard/webapp/metrics/store/BUILD @@ -92,6 +92,7 @@ tf_ts_library( "//tensorboard/webapp/metrics/actions", "//tensorboard/webapp/metrics/data_source", "//tensorboard/webapp/persistent_settings", + "//tensorboard/webapp/routes:testing", "//tensorboard/webapp/types", "//tensorboard/webapp/util:dom", "@npm//@types/jasmine", diff --git a/tensorboard/webapp/metrics/store/metrics_reducers.ts b/tensorboard/webapp/metrics/store/metrics_reducers.ts index a91e21c82f..068624c1b8 100644 --- a/tensorboard/webapp/metrics/store/metrics_reducers.ts +++ b/tensorboard/webapp/metrics/store/metrics_reducers.ts @@ -347,11 +347,16 @@ const reducer = createReducer( }; } - return { + const newState = { ...state, ...resolvedResult, settingOverrides: newSettings, }; + + if (hydratedState.metrics.tagFilter !== null) { + newState.tagFilter = hydratedState.metrics.tagFilter; + } + return newState; }), on(globalSettingsLoaded, (state, {partialSettings}) => { const metricsSettings: Partial = {}; diff --git a/tensorboard/webapp/metrics/store/metrics_reducers_test.ts b/tensorboard/webapp/metrics/store/metrics_reducers_test.ts index 104aae109e..f2645258e3 100644 --- a/tensorboard/webapp/metrics/store/metrics_reducers_test.ts +++ b/tensorboard/webapp/metrics/store/metrics_reducers_test.ts @@ -17,6 +17,7 @@ import {buildRoute} from '../../app_routing/testing'; import {RouteKind} from '../../app_routing/types'; import * as coreActions from '../../core/actions'; import {globalSettingsLoaded} from '../../persistent_settings'; +import {buildDeserializedState} from '../../routes/testing'; import {DataLoadState} from '../../types/data'; import {nextElementId} from '../../util/dom'; import * as actions from '../actions'; @@ -1778,6 +1779,47 @@ describe('metrics reducers', () => { }); }); + describe('tag filter hydration', () => { + it('rehydrates the value', () => { + const beforeState = buildMetricsState({tagFilter: 'foo'}); + const action = routingActions.stateRehydratedFromUrl({ + routeKind: RouteKind.EXPERIMENT, + partialState: { + metrics: {...buildDeserializedState().metrics, tagFilter: 'bar'}, + }, + }); + const nextState = reducers(beforeState, action); + + expect(nextState.tagFilter).toBe('bar'); + }); + + it('rehydrates an empty string value', () => { + const beforeState = buildMetricsState({tagFilter: 'foo'}); + const action = routingActions.stateRehydratedFromUrl({ + routeKind: RouteKind.EXPERIMENT, + partialState: { + metrics: {...buildDeserializedState().metrics, tagFilter: ''}, + }, + }); + const nextState = reducers(beforeState, action); + + expect(nextState.tagFilter).toBe(''); + }); + + it('does not hydrate when the value is null', () => { + const beforeState = buildMetricsState({tagFilter: 'foo'}); + const action = routingActions.stateRehydratedFromUrl({ + routeKind: RouteKind.EXPERIMENT, + partialState: { + metrics: {...buildDeserializedState().metrics, tagFilter: null}, + }, + }); + const nextState = reducers(beforeState, action); + + expect(nextState.tagFilter).toBe('foo'); + }); + }); + describe('#globalSettingsLoaded', () => { it('adds partial state from loading the settings to the (default) settings', () => { const beforeState = buildMetricsState({ diff --git a/tensorboard/webapp/routes/BUILD b/tensorboard/webapp/routes/BUILD index 077d657298..367a8d9f5a 100644 --- a/tensorboard/webapp/routes/BUILD +++ b/tensorboard/webapp/routes/BUILD @@ -52,6 +52,17 @@ tf_ts_library( ], ) +tf_ts_library( + name = "testing", + testonly = True, + srcs = [ + "testing.ts", + ], + deps = [ + ":dashboard_deeplink_provider_types", + ], +) + tf_ts_library( name = "routes_test_lib", testonly = True, @@ -60,6 +71,7 @@ tf_ts_library( ], deps = [ ":dashboard_deeplink_provider", + ":testing", "//tensorboard/webapp:app_state", "//tensorboard/webapp:selectors", "//tensorboard/webapp/angular:expect_angular_core_testing", diff --git a/tensorboard/webapp/routes/dashboard_deeplink_provider.ts b/tensorboard/webapp/routes/dashboard_deeplink_provider.ts index 2d4e1c5bb4..04e96e3481 100644 --- a/tensorboard/webapp/routes/dashboard_deeplink_provider.ts +++ b/tensorboard/webapp/routes/dashboard_deeplink_provider.ts @@ -38,6 +38,7 @@ import { DeserializedState, PINNED_CARDS_KEY, SMOOTHING_KEY, + TAG_FILTER_KEY, } from './dashboard_deeplink_provider_types'; const COLOR_GROUP_REGEX_VALUE_PREFIX = 'regex:'; @@ -117,6 +118,14 @@ export class DashboardDeepLinkProvider extends DeepLinkProvider { ): Observable { return combineLatest([ this.getMetricsPinnedCards(store), + store.select(selectors.getMetricsTagFilter).pipe( + map((filterText) => { + if (!filterText) { + return []; + } + return [{key: TAG_FILTER_KEY, value: filterText}]; + }) + ), this.getFeatureFlagStates(store), store.select(selectors.getMetricsSettingOverrides).pipe( map((settingOverrides) => { @@ -165,6 +174,7 @@ export class DashboardDeepLinkProvider extends DeepLinkProvider { ): DeserializedState { let pinnedCards = null; let smoothing = null; + let tagFilter = null; let groupBy: GroupBy | null = null; for (const {key, value} of queryParams) { @@ -193,6 +203,9 @@ export class DashboardDeepLinkProvider extends DeepLinkProvider { } break; } + case TAG_FILTER_KEY: + tagFilter = value; + break; } } @@ -200,6 +213,7 @@ export class DashboardDeepLinkProvider extends DeepLinkProvider { metrics: { pinnedCards: pinnedCards || [], smoothing, + tagFilter, }, runs: { groupBy, diff --git a/tensorboard/webapp/routes/dashboard_deeplink_provider_test.ts b/tensorboard/webapp/routes/dashboard_deeplink_provider_test.ts index 383d199c34..f02ab42176 100644 --- a/tensorboard/webapp/routes/dashboard_deeplink_provider_test.ts +++ b/tensorboard/webapp/routes/dashboard_deeplink_provider_test.ts @@ -24,6 +24,7 @@ import {appStateFromMetricsState, buildMetricsState} from '../metrics/testing'; import * as selectors from '../selectors'; import {DashboardDeepLinkProvider} from './dashboard_deeplink_provider'; import {GroupBy, GroupByKey} from '../runs/types'; +import {buildDeserializedState} from './testing'; describe('core deeplink provider', () => { let store: MockStore; @@ -206,17 +207,21 @@ describe('core deeplink provider', () => { }, ]); - expect(state.metrics).toEqual({ - pinnedCards: [ - {plugin: PluginType.SCALARS, tag: 'accuracy'}, - { - plugin: PluginType.IMAGES, - tag: 'loss', - runId: 'exp1/123', - sample: 5, - }, - ], - smoothing: null, + const defaultState = buildDeserializedState(); + expect(state).toEqual({ + ...defaultState, + metrics: { + ...defaultState.metrics, + pinnedCards: [ + {plugin: PluginType.SCALARS, tag: 'accuracy'}, + { + plugin: PluginType.IMAGES, + tag: 'loss', + runId: 'exp1/123', + sample: 5, + }, + ], + }, }); }); @@ -303,6 +308,43 @@ describe('core deeplink provider', () => { } }); }); + + describe('tag filter', () => { + it('serializes the filter text to the URL', () => { + store.overrideSelector(selectors.getMetricsTagFilter, 'accuracy'); + store.refreshState(); + + expect(queryParamsSerialized.slice(-1)[0]).toEqual([ + {key: 'tagFilter', value: 'accuracy'}, + ]); + }); + + it('does not serialize an empty string', () => { + store.overrideSelector(selectors.getMetricsTagFilter, ''); + store.refreshState(); + + expect(queryParamsSerialized).toEqual([]); + }); + + it('deserializes the string from the URL', () => { + const state1 = provider.deserializeQueryParams([ + {key: 'tagFilter', value: 'accuracy'}, + ]); + expect(state1.metrics.tagFilter).toBe('accuracy'); + }); + + it('deserializes the empty string from the URL', () => { + const state1 = provider.deserializeQueryParams([ + {key: 'tagFilter', value: ''}, + ]); + expect(state1.metrics.tagFilter).toBe(''); + }); + + it('deserializes to null when no value is provided', () => { + const state = provider.deserializeQueryParams([]); + expect(state.metrics.tagFilter).toBe(null); + }); + }); }); describe('feature flag', () => { diff --git a/tensorboard/webapp/routes/dashboard_deeplink_provider_types.ts b/tensorboard/webapp/routes/dashboard_deeplink_provider_types.ts index e50ca8ad4e..5acd77b7bb 100644 --- a/tensorboard/webapp/routes/dashboard_deeplink_provider_types.ts +++ b/tensorboard/webapp/routes/dashboard_deeplink_provider_types.ts @@ -25,3 +25,5 @@ export const SMOOTHING_KEY = 'smoothing'; export const PINNED_CARDS_KEY = 'pinnedCards'; export const RUN_COLOR_GROUP_KEY = 'runColorGroup'; + +export const TAG_FILTER_KEY = 'tagFilter'; diff --git a/tensorboard/webapp/routes/testing.ts b/tensorboard/webapp/routes/testing.ts new file mode 100644 index 0000000000..cd35d2053c --- /dev/null +++ b/tensorboard/webapp/routes/testing.ts @@ -0,0 +1,31 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +import {DeserializedState} from './dashboard_deeplink_provider_types'; + +export function buildDeserializedState( + override: Partial = {} +) { + return { + runs: { + groupBy: null, + }, + metrics: { + pinnedCards: [], + smoothing: null, + tagFilter: null, + }, + ...override, + }; +}