diff --git a/tensorboard/plugins/graph/tf_graph/tf-graph.ts b/tensorboard/plugins/graph/tf_graph/tf-graph.ts index 8420c10064..96c69c877a 100644 --- a/tensorboard/plugins/graph/tf_graph/tf-graph.ts +++ b/tensorboard/plugins/graph/tf_graph/tf-graph.ts @@ -95,7 +95,7 @@ class TfGraph extends LegacyElementMixin(PolymerElement) { @property({type: Object}) devicesForStats: object; @property({type: Object}) - hierarchyParams: any; + hierarchyParams: tf_graph_hierarchy.HierarchyParams; @property({ type: Object, notify: true, @@ -479,9 +479,6 @@ class TfGraph extends LegacyElementMixin(PolymerElement) { this.nodeToggleSeriesGroup(nodeName); } nodeToggleSeriesGroup(nodeName) { - // Toggle the group setting of the specified node appropriately. - tf_graph.toggleNodeSeriesGroup(this.hierarchyParams.seriesMap, nodeName); - // Rebuild the render hierarchy with the updated series grouping map. this.set('progress', { value: 0, msg: '', @@ -492,8 +489,15 @@ class TfGraph extends LegacyElementMixin(PolymerElement) { 100, 'Namespace hierarchy' ); + + // Toggle the node's group type, setting to 'UNGROUP' if unspecified. + const newHierarchyParams = { + ...this.hierarchyParams, + seriesMap: this.graphHierarchy.buildSeriesGroupMapToggled(nodeName), + }; + tf_graph_hierarchy - .build(this.basicGraph, this.hierarchyParams, hierarchyTracker) + .build(this.basicGraph, newHierarchyParams, hierarchyTracker) .then( function (graphHierarchy) { this.set('graphHierarchy', graphHierarchy); diff --git a/tensorboard/plugins/graph/tf_graph_board/tf-graph-board.ts b/tensorboard/plugins/graph/tf_graph_board/tf-graph-board.ts index 35f7e260fa..a7840b9aef 100644 --- a/tensorboard/plugins/graph/tf_graph_board/tf-graph-board.ts +++ b/tensorboard/plugins/graph/tf_graph_board/tf-graph-board.ts @@ -22,7 +22,7 @@ import * as tf_graph from '../tf_graph_common/graph'; import * as tf_graph_render from '../tf_graph_common/render'; import {LegacyElementMixin} from '../../../components/polymer/legacy_element_mixin'; import {ColorBy} from '../tf_graph_common/view_types'; -import {Hierarchy} from '../tf_graph_common/hierarchy'; +import * as tf_graph_hierarchy from '../tf_graph_common/hierarchy'; /** * Some UX features, such as 'color by structure', rely on the 'template' @@ -174,7 +174,6 @@ class TfGraphBoard extends LegacyElementMixin(PolymerElement) { id="graph-info" title="selected" graph-hierarchy="[[graphHierarchy]]" - hierarchy-params="[[hierarchyParams]]" render-hierarchy="[[renderHierarchy]]" graph="[[graph]]" selected-node="{{selectedNode}}" @@ -197,9 +196,14 @@ class TfGraphBoard extends LegacyElementMixin(PolymerElement) { `; @property({type: Object}) - graphHierarchy: Hierarchy; + graphHierarchy: tf_graph_hierarchy.Hierarchy; @property({type: Object}) graph: tf_graph.SlimGraph; + // TODO(psybuzz): ideally, this would be a required property and the component + // that owns and the graph loader should create these params. + @property({type: Object}) + hierarchyParams: tf_graph_hierarchy.HierarchyParams = + tf_graph_hierarchy.DefaultHierarchyParams; @property({type: Object}) stats: object; /** diff --git a/tensorboard/plugins/graph/tf_graph_common/common_test.ts b/tensorboard/plugins/graph/tf_graph_common/common_test.ts index 93e720251f..845624959c 100644 --- a/tensorboard/plugins/graph/tf_graph_common/common_test.ts +++ b/tensorboard/plugins/graph/tf_graph_common/common_test.ts @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ import * as tb_debug from '../../../components/tb_debug'; +import * as tf_graph from './graph'; +import * as tf_graph_hierarchy from './hierarchy'; import * as tf_graph_loader from './loader'; import * as tf_graph_parser from './parser'; @@ -127,4 +129,68 @@ describe('graph tests', () => { }, ]); }); + + describe('Graph Hierarchy', () => { + it('does not mutate the provided seriesMap by reference', () => { + const readonlySeriesMap = new Map([ + ['fooNode', tf_graph.SeriesGroupingType.UNGROUP], + ['barNode', tf_graph.SeriesGroupingType.GROUP], + ]); + const params = { + ...tf_graph_hierarchy.DefaultHierarchyParams, + seriesMap: readonlySeriesMap, + }; + + const hierarchy = new tf_graph_hierarchy.Hierarchy(params); + hierarchy.setSeriesGroupType( + 'barNode', + tf_graph.SeriesGroupingType.UNGROUP + ); + + // The Hierarchy's map should update. + expect(hierarchy.getSeriesGroupType('barNode')).toBe( + tf_graph.SeriesGroupingType.UNGROUP + ); + + // The original map should not. + expect(params.seriesMap.get('barNode')).toBe( + tf_graph.SeriesGroupingType.GROUP + ); + }); + + it('builds a toggled seriesMap', () => { + const hierarchy = new tf_graph_hierarchy.Hierarchy({ + ...tf_graph_hierarchy.DefaultHierarchyParams, + seriesMap: new Map([ + ['fooNode', tf_graph.SeriesGroupingType.UNGROUP], + ['barNode', tf_graph.SeriesGroupingType.GROUP], + ]), + }); + + const result1 = hierarchy.buildSeriesGroupMapToggled('fooNode'); + expect(result1).toEqual( + new Map([ + ['fooNode', tf_graph.SeriesGroupingType.GROUP], + ['barNode', tf_graph.SeriesGroupingType.GROUP], + ]) + ); + + const result2 = hierarchy.buildSeriesGroupMapToggled('barNode'); + expect(result2).toEqual( + new Map([ + ['fooNode', tf_graph.SeriesGroupingType.UNGROUP], + ['barNode', tf_graph.SeriesGroupingType.UNGROUP], + ]) + ); + + const result3 = hierarchy.buildSeriesGroupMapToggled('unknownNode'); + expect(result3).toEqual( + new Map([ + ['fooNode', tf_graph.SeriesGroupingType.UNGROUP], + ['barNode', tf_graph.SeriesGroupingType.GROUP], + ['unknownNode', tf_graph.SeriesGroupingType.UNGROUP], + ]) + ); + }); + }); }); diff --git a/tensorboard/plugins/graph/tf_graph_common/graph.ts b/tensorboard/plugins/graph/tf_graph_common/graph.ts index 57cb33c93e..77c19e9fa9 100644 --- a/tensorboard/plugins/graph/tf_graph_common/graph.ts +++ b/tensorboard/plugins/graph/tf_graph_common/graph.ts @@ -1452,22 +1452,6 @@ export function getGroupSeriesNodeButtonString(group: SeriesGroupingType) { return 'Group this series of nodes'; } } -/** - * Toggle the node series grouping option in the provided map, setting it - * to ungroup if the series is not already in the map. - */ -export function toggleNodeSeriesGroup( - map: { - [name: string]: SeriesGroupingType; - }, - name: string -) { - if (!(name in map) || map[name] === SeriesGroupingType.GROUP) { - map[name] = SeriesGroupingType.UNGROUP; - } else { - map[name] = SeriesGroupingType.GROUP; - } -} export interface Edges { control: Metaedge[]; diff --git a/tensorboard/plugins/graph/tf_graph_common/hierarchy.ts b/tensorboard/plugins/graph/tf_graph_common/hierarchy.ts index befaf3b452..c4d20e1c1b 100644 --- a/tensorboard/plugins/graph/tf_graph_common/hierarchy.ts +++ b/tensorboard/plugins/graph/tf_graph_common/hierarchy.ts @@ -55,6 +55,9 @@ export enum HierarchyEvent { TEMPLATES_UPDATED, } +// A map from the name of a series node to its grouping type. +type SeriesGroupMap = Map; + /** * Class for the Graph Hierarchy for TensorFlow graph. */ @@ -91,6 +94,7 @@ export class Hierarchy extends tf_graph_util.Dispatcher { private index: { [nodeName: string]: GroupNode | OpNode; }; + private readonly seriesGroupMap: SeriesGroupMap; constructor(params: HierarchyParams) { super(); @@ -98,6 +102,7 @@ export class Hierarchy extends tf_graph_util.Dispatcher { this.graphOptions.rankdir = params.rankDirection; this.root = createMetanode(ROOT_NAME, this.graphOptions); this.libraryFunctions = {}; + this.seriesGroupMap = new Map(params.seriesMap); this.devices = null; this.xlaClusters = null; this.verifyTemplate = params.verifyTemplate; @@ -109,6 +114,24 @@ export class Hierarchy extends tf_graph_util.Dispatcher { this.index[ROOT_NAME] = this.root; this.orderings = {}; } + getSeriesGroupType(nodeName: string): tf_graph.SeriesGroupingType { + // If grouping was not specified, assume it should be grouped by default. + return ( + this.seriesGroupMap.get(nodeName) ?? tf_graph.SeriesGroupingType.GROUP + ); + } + setSeriesGroupType(nodeName: string, groupType: tf_graph.SeriesGroupingType) { + return this.seriesGroupMap.set(nodeName, groupType); + } + buildSeriesGroupMapToggled( + nodeName: string + ): Map { + const newGroupType = + this.getSeriesGroupType(nodeName) === tf_graph.SeriesGroupingType.GROUP + ? tf_graph.SeriesGroupingType.UNGROUP + : tf_graph.SeriesGroupingType.GROUP; + return new Map([...this.seriesGroupMap, [nodeName, newGroupType]]); + } getNodeMap(): { [nodeName: string]: GroupNode | OpNode; } { @@ -442,9 +465,8 @@ function findEdgeTargetsInGraph( export interface HierarchyParams { verifyTemplate: boolean; seriesNodeMinSize: number; - seriesMap: { - [name: string]: tf_graph.SeriesGroupingType; - }; + // The initial map of explicit series group types. + seriesMap: SeriesGroupMap; // This string is supplied to dagre as the 'rankdir' property for laying out // the graph. TB, BT, LR, or RL. The default is 'BT' (bottom to top). rankDirection: string; @@ -456,7 +478,7 @@ export interface HierarchyParams { export const DefaultHierarchyParams: HierarchyParams = { verifyTemplate: true, seriesNodeMinSize: 5, - seriesMap: {}, + seriesMap: new Map(), rankDirection: 'BT', useGeneralizedSeriesPatterns: false, }; @@ -579,11 +601,10 @@ export function joinAndAggregateStats( }); } export function getIncompatibleOps( - hierarchy: Hierarchy, - hierarchyParams: HierarchyParams -) { - let nodes: (GroupNode | OpNode)[] = []; - let addedSeriesNodes: { + hierarchy: Hierarchy +): Array { + const nodes: Array = []; + const addedSeriesNodes: { [seriesName: string]: SeriesNode; } = {}; _.each(hierarchy.root.leaves(), (leaf) => { @@ -593,9 +614,8 @@ export function getIncompatibleOps( if (!opNode.compatible) { if (opNode.owningSeries) { if ( - hierarchyParams && - hierarchyParams.seriesMap[opNode.owningSeries] === - tf_graph.SeriesGroupingType.UNGROUP + hierarchy.getSeriesGroupType(opNode.owningSeries) === + tf_graph.SeriesGroupingType.UNGROUP ) { // For un-grouped series node, add each node individually nodes.push(opNode); @@ -843,9 +863,7 @@ function groupSeries( [name: string]: string; }, threshold: number, - map: { - [name: string]: tf_graph.SeriesGroupingType; - }, + seriesMap: SeriesGroupMap, useGeneralizedSeriesPatterns: boolean ) { let metagraph = metanode.metagraph; @@ -857,7 +875,7 @@ function groupSeries( hierarchy, seriesNames, threshold, - map, + seriesMap, useGeneralizedSeriesPatterns ); } @@ -881,16 +899,22 @@ function groupSeries( child.owningSeries = seriesName; } }); - // If the series contains less than the threshold number of nodes and - // this series has not been adding to the series map, then set this - // series to be shown ungrouped in the map. - if (nodeMemberNames.length < threshold && !(seriesNode.name in map)) { - map[seriesNode.name] = tf_graph.SeriesGroupingType.UNGROUP; + // If the series contains less than the threshold number of nodes, then set + // this series to be shown ungrouped in the map. + if ( + nodeMemberNames.length < threshold && + hierarchy.getSeriesGroupType(seriesNode.name) === + tf_graph.SeriesGroupingType.GROUP + ) { + hierarchy.setSeriesGroupType( + seriesNode.name, + tf_graph.SeriesGroupingType.UNGROUP + ); } // If the series is in the map as ungrouped then do not group the series. if ( - seriesNode.name in map && - map[seriesNode.name] === tf_graph.SeriesGroupingType.UNGROUP + hierarchy.getSeriesGroupType(seriesNode.name) === + tf_graph.SeriesGroupingType.UNGROUP ) { return; } diff --git a/tensorboard/plugins/graph/tf_graph_info/tf-graph-info.ts b/tensorboard/plugins/graph/tf_graph_info/tf-graph-info.ts index 9d8db20c0e..ec6aeef709 100644 --- a/tensorboard/plugins/graph/tf_graph_info/tf-graph-info.ts +++ b/tensorboard/plugins/graph/tf_graph_info/tf-graph-info.ts @@ -63,7 +63,6 @@ class TfGraphInfo extends LegacyElementMixin(PolymerElement) {