Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions tensorboard/plugins/graph/tf_graph/tf-graph.ts
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ class TfGraph extends LegacyElementMixin(PolymerElement) {
@property({type: Object})
devicesForStats: object;
@property({type: Object})
hierarchyParams: any;
hierarchyParams: tf_graph_hierarchy.HierarchyParams;
@property({
type: Object,
notify: true,
Expand Down Expand Up @@ -479,9 +479,6 @@ class TfGraph extends LegacyElementMixin(PolymerElement) {
this.nodeToggleSeriesGroup(nodeName);
}
nodeToggleSeriesGroup(nodeName) {
// Toggle the group setting of the specified node appropriately.
tf_graph.toggleNodeSeriesGroup(this.hierarchyParams.seriesMap, nodeName);
// Rebuild the render hierarchy with the updated series grouping map.
this.set('progress', {
value: 0,
msg: '',
Expand All @@ -492,8 +489,15 @@ class TfGraph extends LegacyElementMixin(PolymerElement) {
100,
'Namespace hierarchy'
);

// Toggle the node's group type, setting to 'UNGROUP' if unspecified.
const newHierarchyParams = {
...this.hierarchyParams,
seriesMap: this.graphHierarchy.buildSeriesGroupMapToggled(nodeName),
};

tf_graph_hierarchy
.build(this.basicGraph, this.hierarchyParams, hierarchyTracker)
.build(this.basicGraph, newHierarchyParams, hierarchyTracker)
.then(
function (graphHierarchy) {
this.set('graphHierarchy', graphHierarchy);
Expand Down
10 changes: 7 additions & 3 deletions tensorboard/plugins/graph/tf_graph_board/tf-graph-board.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import * as tf_graph from '../tf_graph_common/graph';
import * as tf_graph_render from '../tf_graph_common/render';
import {LegacyElementMixin} from '../../../components/polymer/legacy_element_mixin';
import {ColorBy} from '../tf_graph_common/view_types';
import {Hierarchy} from '../tf_graph_common/hierarchy';
import * as tf_graph_hierarchy from '../tf_graph_common/hierarchy';

/**
* Some UX features, such as 'color by structure', rely on the 'template'
Expand Down Expand Up @@ -174,7 +174,6 @@ class TfGraphBoard extends LegacyElementMixin(PolymerElement) {
id="graph-info"
title="selected"
graph-hierarchy="[[graphHierarchy]]"
hierarchy-params="[[hierarchyParams]]"
render-hierarchy="[[renderHierarchy]]"
graph="[[graph]]"
selected-node="{{selectedNode}}"
Expand All @@ -197,9 +196,14 @@ class TfGraphBoard extends LegacyElementMixin(PolymerElement) {
</div>
`;
@property({type: Object})
graphHierarchy: Hierarchy;
graphHierarchy: tf_graph_hierarchy.Hierarchy;
@property({type: Object})
graph: tf_graph.SlimGraph;
// TODO(psybuzz): ideally, this would be a required property and the component
// that owns <tf-graph-board> and the graph loader should create these params.
@property({type: Object})
hierarchyParams: tf_graph_hierarchy.HierarchyParams =
tf_graph_hierarchy.DefaultHierarchyParams;
@property({type: Object})
stats: object;
/**
Expand Down
66 changes: 66 additions & 0 deletions tensorboard/plugins/graph/tf_graph_common/common_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
import * as tb_debug from '../../../components/tb_debug';
import * as tf_graph from './graph';
import * as tf_graph_hierarchy from './hierarchy';
import * as tf_graph_loader from './loader';
import * as tf_graph_parser from './parser';

Expand Down Expand Up @@ -127,4 +129,68 @@ describe('graph tests', () => {
},
]);
});

describe('Graph Hierarchy', () => {
it('does not mutate the provided seriesMap by reference', () => {
const readonlySeriesMap = new Map([
['fooNode', tf_graph.SeriesGroupingType.UNGROUP],
['barNode', tf_graph.SeriesGroupingType.GROUP],
]);
Comment on lines +135 to +138
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

idea: another way to have this test would be to Object.freeze it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Interesting idea, although I don't know of any way to freeze a Map without writing a custom FrozenMap.

const params = {
...tf_graph_hierarchy.DefaultHierarchyParams,
seriesMap: readonlySeriesMap,
};

const hierarchy = new tf_graph_hierarchy.Hierarchy(params);
hierarchy.setSeriesGroupType(
'barNode',
tf_graph.SeriesGroupingType.UNGROUP
);

// The Hierarchy's map should update.
expect(hierarchy.getSeriesGroupType('barNode')).toBe(
tf_graph.SeriesGroupingType.UNGROUP
);

// The original map should not.
expect(params.seriesMap.get('barNode')).toBe(
tf_graph.SeriesGroupingType.GROUP
);
});

it('builds a toggled seriesMap', () => {
const hierarchy = new tf_graph_hierarchy.Hierarchy({
...tf_graph_hierarchy.DefaultHierarchyParams,
seriesMap: new Map([
['fooNode', tf_graph.SeriesGroupingType.UNGROUP],
['barNode', tf_graph.SeriesGroupingType.GROUP],
]),
});

const result1 = hierarchy.buildSeriesGroupMapToggled('fooNode');
expect(result1).toEqual(
new Map([
['fooNode', tf_graph.SeriesGroupingType.GROUP],
['barNode', tf_graph.SeriesGroupingType.GROUP],
])
);

const result2 = hierarchy.buildSeriesGroupMapToggled('barNode');
expect(result2).toEqual(
new Map([
['fooNode', tf_graph.SeriesGroupingType.UNGROUP],
['barNode', tf_graph.SeriesGroupingType.UNGROUP],
])
);

const result3 = hierarchy.buildSeriesGroupMapToggled('unknownNode');
expect(result3).toEqual(
new Map([
['fooNode', tf_graph.SeriesGroupingType.UNGROUP],
['barNode', tf_graph.SeriesGroupingType.GROUP],
['unknownNode', tf_graph.SeriesGroupingType.UNGROUP],
])
);
});
});
});
16 changes: 0 additions & 16 deletions tensorboard/plugins/graph/tf_graph_common/graph.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1452,22 +1452,6 @@ export function getGroupSeriesNodeButtonString(group: SeriesGroupingType) {
return 'Group this series of nodes';
}
}
/**
* Toggle the node series grouping option in the provided map, setting it
* to ungroup if the series is not already in the map.
*/
export function toggleNodeSeriesGroup(
map: {
[name: string]: SeriesGroupingType;
},
name: string
) {
if (!(name in map) || map[name] === SeriesGroupingType.GROUP) {
map[name] = SeriesGroupingType.UNGROUP;
} else {
map[name] = SeriesGroupingType.GROUP;
}
}

export interface Edges {
control: Metaedge[];
Expand Down
70 changes: 47 additions & 23 deletions tensorboard/plugins/graph/tf_graph_common/hierarchy.ts
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@ export enum HierarchyEvent {
TEMPLATES_UPDATED,
}

// A map from the name of a series node to its grouping type.
type SeriesGroupMap = Map<string, tf_graph.SeriesGroupingType>;

/**
* Class for the Graph Hierarchy for TensorFlow graph.
*/
Expand Down Expand Up @@ -91,13 +94,15 @@ export class Hierarchy extends tf_graph_util.Dispatcher<HierarchyEvent> {
private index: {
[nodeName: string]: GroupNode | OpNode;
};
private readonly seriesGroupMap: SeriesGroupMap;

constructor(params: HierarchyParams) {
super();
this.graphOptions.compound = true;
this.graphOptions.rankdir = params.rankDirection;
this.root = createMetanode(ROOT_NAME, this.graphOptions);
this.libraryFunctions = {};
this.seriesGroupMap = new Map(params.seriesMap);
this.devices = null;
this.xlaClusters = null;
this.verifyTemplate = params.verifyTemplate;
Expand All @@ -109,6 +114,24 @@ export class Hierarchy extends tf_graph_util.Dispatcher<HierarchyEvent> {
this.index[ROOT_NAME] = this.root;
this.orderings = {};
}
getSeriesGroupType(nodeName: string): tf_graph.SeriesGroupingType {
// If grouping was not specified, assume it should be grouped by default.
return (
this.seriesGroupMap.get(nodeName) ?? tf_graph.SeriesGroupingType.GROUP
);
}
setSeriesGroupType(nodeName: string, groupType: tf_graph.SeriesGroupingType) {
return this.seriesGroupMap.set(nodeName, groupType);
}
buildSeriesGroupMapToggled(
nodeName: string
): Map<string, tf_graph.SeriesGroupingType> {
const newGroupType =
this.getSeriesGroupType(nodeName) === tf_graph.SeriesGroupingType.GROUP
? tf_graph.SeriesGroupingType.UNGROUP
: tf_graph.SeriesGroupingType.GROUP;
return new Map([...this.seriesGroupMap, [nodeName, newGroupType]]);
}
getNodeMap(): {
[nodeName: string]: GroupNode | OpNode;
} {
Expand Down Expand Up @@ -442,9 +465,8 @@ function findEdgeTargetsInGraph(
export interface HierarchyParams {
verifyTemplate: boolean;
seriesNodeMinSize: number;
seriesMap: {
[name: string]: tf_graph.SeriesGroupingType;
};
// The initial map of explicit series group types.
seriesMap: SeriesGroupMap;
// This string is supplied to dagre as the 'rankdir' property for laying out
// the graph. TB, BT, LR, or RL. The default is 'BT' (bottom to top).
rankDirection: string;
Expand All @@ -456,7 +478,7 @@ export interface HierarchyParams {
export const DefaultHierarchyParams: HierarchyParams = {
verifyTemplate: true,
seriesNodeMinSize: 5,
seriesMap: {},
seriesMap: new Map(),
rankDirection: 'BT',
useGeneralizedSeriesPatterns: false,
};
Expand Down Expand Up @@ -579,11 +601,10 @@ export function joinAndAggregateStats(
});
}
export function getIncompatibleOps(
hierarchy: Hierarchy,
hierarchyParams: HierarchyParams
) {
let nodes: (GroupNode | OpNode)[] = [];
let addedSeriesNodes: {
hierarchy: Hierarchy
): Array<GroupNode | OpNode> {
const nodes: Array<GroupNode | OpNode> = [];
const addedSeriesNodes: {
[seriesName: string]: SeriesNode;
} = {};
_.each(hierarchy.root.leaves(), (leaf) => {
Expand All @@ -593,9 +614,8 @@ export function getIncompatibleOps(
if (!opNode.compatible) {
if (opNode.owningSeries) {
if (
hierarchyParams &&
hierarchyParams.seriesMap[opNode.owningSeries] ===
tf_graph.SeriesGroupingType.UNGROUP
hierarchy.getSeriesGroupType(opNode.owningSeries) ===
tf_graph.SeriesGroupingType.UNGROUP
) {
// For un-grouped series node, add each node individually
nodes.push(opNode);
Expand Down Expand Up @@ -843,9 +863,7 @@ function groupSeries(
[name: string]: string;
},
threshold: number,
map: {
[name: string]: tf_graph.SeriesGroupingType;
},
seriesMap: SeriesGroupMap,
useGeneralizedSeriesPatterns: boolean
) {
let metagraph = metanode.metagraph;
Expand All @@ -857,7 +875,7 @@ function groupSeries(
hierarchy,
seriesNames,
threshold,
map,
seriesMap,
useGeneralizedSeriesPatterns
);
}
Expand All @@ -881,16 +899,22 @@ function groupSeries(
child.owningSeries = seriesName;
}
});
// If the series contains less than the threshold number of nodes and
// this series has not been adding to the series map, then set this
// series to be shown ungrouped in the map.
if (nodeMemberNames.length < threshold && !(seriesNode.name in map)) {
map[seriesNode.name] = tf_graph.SeriesGroupingType.UNGROUP;
// If the series contains less than the threshold number of nodes, then set
// this series to be shown ungrouped in the map.
if (
nodeMemberNames.length < threshold &&
hierarchy.getSeriesGroupType(seriesNode.name) ===
tf_graph.SeriesGroupingType.GROUP
Comment on lines +906 to +907
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Previous code reads !(seriesNode.name in map) and it makes me think that this code is basically checking to see if the hierarchy.getSeriesGroupType(seriesNode.name) has the default value in case it does not exist. I think it is more prudent to null to Map.prototype.has than matching against the default value defined in various places in this module.

Would it not read better if the hierarchy.getSeriesGroupType(seriesNode.name) were to return null?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would still prefer to hide the fact that the group types are stored implicitly, because it is a concern that is specific to Hierarchy internals. External consumers of the Hierarchy actually do not need to know what null values mean, so exposing only Group/Ungroup keeps consumers less complex and easier to reason about, imo.

While the previous code does check whether a node has been explicitly grouped/ungrouped, the new code matches the same intent (if a grouped series is too short, then ungroup it), and does not require readers to understand anything about the default group type.

I've updated the comment above on L900 accordingly.

) {
hierarchy.setSeriesGroupType(
seriesNode.name,
tf_graph.SeriesGroupingType.UNGROUP
);
}
// If the series is in the map as ungrouped then do not group the series.
if (
seriesNode.name in map &&
map[seriesNode.name] === tf_graph.SeriesGroupingType.UNGROUP
hierarchy.getSeriesGroupType(seriesNode.name) ===
tf_graph.SeriesGroupingType.UNGROUP
) {
return;
}
Expand Down
1 change: 0 additions & 1 deletion tensorboard/plugins/graph/tf_graph_info/tf-graph-info.ts
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@ class TfGraphInfo extends LegacyElementMixin(PolymerElement) {
<template is="dom-if" if="[[_equals(colorBy, 'op_compatibility')]]">
<tf-graph-op-compat-card
graph-hierarchy="[[graphHierarchy]]"
hierarchy-params="[[hierarchyParams]]"
render-hierarchy="[[renderHierarchy]]"
color-by="[[colorBy]]"
node-title="[[compatNodeTitle]]"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -182,8 +182,6 @@ class TfGraphOpCompatCard extends LegacyElementMixin(PolymerElement) {
@property({type: Object})
graphHierarchy: tf_graph_hierarchy.Hierarchy;
@property({type: Object})
hierarchyParams: object;
@property({type: Object})
renderHierarchy: tf_graph_render.RenderGraphInfo;
@property({type: String})
nodeTitle: string;
Expand Down Expand Up @@ -219,17 +217,14 @@ class TfGraphOpCompatCard extends LegacyElementMixin(PolymerElement) {
list.fire('iron-resize');
}
}
@computed('graphHierarchy', 'hierarchyParams')
get _incompatibleOpNodes(): object {
var graphHierarchy = this.graphHierarchy;
var hierarchyParams = this.hierarchyParams;
if (graphHierarchy && graphHierarchy.root) {
this.async(this._resizeList.bind(this, '#incompatibleOpsList'));
return tf_graph_hierarchy.getIncompatibleOps(
graphHierarchy,
hierarchyParams as any
);
@computed('graphHierarchy')
get _incompatibleOpNodes(): Array<tf_graph.GroupNode | tf_graph.OpNode> {
const graphHierarchy = this.graphHierarchy;
if (!graphHierarchy || !graphHierarchy.root) {
return [];
}
this.async(this._resizeList.bind(this, '#incompatibleOpsList'));
return tf_graph_hierarchy.getIncompatibleOps(graphHierarchy);
}
@computed('graphHierarchy')
get _opCompatScore(): number {
Expand Down