From 76993a5de1a1dfff8c2085df5465bc9590fc5db7 Mon Sep 17 00:00:00 2001 From: Manuel Klimek Date: Fri, 3 Aug 2018 20:14:52 +0200 Subject: [PATCH 1/2] Add histogram coloring for xla clusters similarly to devices. --- .../plugins/graph/tf_graph_common/graph.ts | 10 +++ .../graph/tf_graph_common/hierarchy.ts | 22 +++++- .../plugins/graph/tf_graph_common/node.ts | 65 +++++++++++------- .../plugins/graph/tf_graph_common/render.ts | 68 ++++++++++++------- .../tf_graph_controls/tf-graph-controls.html | 26 +++---- 5 files changed, 123 insertions(+), 68 deletions(-) diff --git a/tensorboard/plugins/graph/tf_graph_common/graph.ts b/tensorboard/plugins/graph/tf_graph_common/graph.ts index 16f7901e21..f715682f74 100644 --- a/tensorboard/plugins/graph/tf_graph_common/graph.ts +++ b/tensorboard/plugins/graph/tf_graph_common/graph.ts @@ -295,6 +295,12 @@ export interface GroupNode extends Node { */ deviceHistogram: {[device: string]: number}; + /** + * Stores how many times each XLA cluster name appears in its children + * op nodes. Used to color group nodes by XLA clusters. + */ + xlaClusterHistogram: {[device: string]: number}; + /** * Stores how many ops in sub-graph were compatible and how many are * incompatible. @@ -595,6 +601,7 @@ export class MetanodeImpl implements Metanode { templateId: string; opHistogram: {[op: string]: number}; deviceHistogram: {[op: string]: number}; + xlaClusterHistogram: {[op: string]: number}; compatibilityHistogram: {compatible: number, incompatible: number}; parentNode: Node; hasNonControlEdges: boolean; @@ -624,6 +631,7 @@ export class MetanodeImpl implements Metanode { */ this.opHistogram = {}; this.deviceHistogram = {}; + this.xlaClusterHistogram = {}; this.compatibilityHistogram = {compatible: 0, incompatible: 0}; /** unique id for a metanode of similar subgraph */ this.templateId = null; @@ -831,6 +839,7 @@ class SeriesNodeImpl implements SeriesNode { bridgegraph: graphlib.Graph; parentNode: Node; deviceHistogram: {[op: string]: number}; + xlaClusterHistogram: {[op: string]: number}; compatibilityHistogram: {compatible: number, incompatible: number}; hasNonControlEdges: boolean; include: InclusionType; @@ -859,6 +868,7 @@ class SeriesNodeImpl implements SeriesNode { this.bridgegraph = null; this.parentNode = null; this.deviceHistogram = {}; + this.xlaClusterHistogram = {}; this.compatibilityHistogram = {compatible: 0, incompatible: 0}; this.hasNonControlEdges = false; this.include = InclusionType.UNSPECIFIED; diff --git a/tensorboard/plugins/graph/tf_graph_common/hierarchy.ts b/tensorboard/plugins/graph/tf_graph_common/hierarchy.ts index c4299df5ee..51d0dd096c 100644 --- a/tensorboard/plugins/graph/tf_graph_common/hierarchy.ts +++ b/tensorboard/plugins/graph/tf_graph_common/hierarchy.ts @@ -87,6 +87,7 @@ class HierarchyImpl implements Hierarchy { this.libraryFunctions = {}; this.templates = null; this.devices = null; + this.xlaClusters = null; /** * @type {Object} Dictionary object that maps node name to the node * (could be op-node, metanode, or series-node) @@ -475,15 +476,20 @@ export function build(graph: tf.graph.SlimGraph, params: HierarchyParams, export function joinAndAggregateStats( h: Hierarchy, stats: tf.graph.proto.StepStats) { - // Get all the possible device names. + // Get all the possible device and XLA cluster names. let deviceNames = {}; + let xlaClusterNames = {}; _.each(h.root.leaves(), nodeName => { let leaf = h.node(nodeName); if (leaf.device != null) { deviceNames[leaf.device] = true; } + if (leaf.xlaCluster != null) { + xlaClusterNames[leaf.xlaCluster] = true; + } }); h.devices = _.keys(deviceNames); + h.xlaClusters = _.keys(xlaClusterNames); // Reset stats for each group node. _.each(h.getNodeMap(), (node, nodeName) => { @@ -502,6 +508,12 @@ export function joinAndAggregateStats( let deviceHistogram = (node.parentNode).deviceHistogram; deviceHistogram[leaf.device] = (deviceHistogram[leaf.device] || 0) + 1; } + if (leaf.xlaCluster != null) { + let xlaClusterHistogram = + (node.parentNode).xlaClusterHistogram; + xlaClusterHistogram[leaf.xlaCluster] = + (xlaClusterHistogram[leaf.xlaCluster] || 0) + 1; + } if (leaf.stats != null) { node.parentNode.stats.combine(leaf.stats); } @@ -592,6 +604,10 @@ function addNodes(h: Hierarchy, graph: SlimGraph) { parent.deviceHistogram[node.device] = (parent.deviceHistogram[node.device] || 0) + 1; } + if (node.xlaCluster != null) { + parent.xlaClusterHistogram[node.xlaCluster] = + (parent.xlaClusterHistogram[node.xlaCluster] || 0) + 1; + } // Increment parents appropriate compatibility count if (node.compatible) { @@ -827,6 +843,10 @@ function groupSeries(metanode: Metanode, hierarchy: Hierarchy, seriesNode.deviceHistogram[child.device] = (seriesNode.deviceHistogram[child.device] || 0) + 1; } + if (child.xlaCluster != null) { + seriesNode.xlaClusterHistogram[child.xlaCluster] = + (seriesNode.xlaClusterHistogram[child.xlaCluster] || 0) + 1; + } // Increment parents appropriate compatibility count if (child.compatible) { diff --git a/tensorboard/plugins/graph/tf_graph_common/node.ts b/tensorboard/plugins/graph/tf_graph_common/node.ts index a8f0b10e10..6e636c409d 100644 --- a/tensorboard/plugins/graph/tf_graph_common/node.ts +++ b/tensorboard/plugins/graph/tf_graph_common/node.ts @@ -612,6 +612,32 @@ function position(nodeGroup, d: render.RenderNodeInfo) { export enum ColorBy {STRUCTURE, DEVICE, XLA_CLUSTER, COMPUTE_TIME, MEMORY, OP_COMPATIBILITY}; +function getGradient( + id: string, colors: Array<{color: string, proportion: number}>) { + let escapedId = tf.graph.util.escapeQuerySelector(id); + let gradientDefs = d3.select('svg#svg defs #linearGradients'); + let linearGradient = gradientDefs.select('linearGradient#' + escapedId); + // If the linear gradient is not there yet, create it. + if (linearGradient.size() === 0) { + linearGradient = gradientDefs.append('linearGradient').attr('id', id); + // Re-create the stops of the linear gradient. + linearGradient.selectAll('*').remove(); + let cumulativeProportion = 0; + // For each color, create a stop using the proportion of that device. + _.each(colors, d => { + let color = d.color; + linearGradient.append('stop') + .attr('offset', cumulativeProportion) + .attr('stop-color', color); + linearGradient.append('stop') + .attr('offset', cumulativeProportion + d.proportion) + .attr('stop-color', color); + cumulativeProportion += d.proportion; + }); + } + return `url(#${escapedId})`; +} + /** * Returns the fill color for the node given its state and the 'color by' * option. @@ -650,32 +676,21 @@ export function getFillForNode(templateIndex, colorBy, // Return the hue for unknown device. return colorParams.UNKNOWN; } - let id = renderInfo.node.name; - let escapedId = tf.graph.util.escapeQuerySelector(id); - let gradientDefs = d3.select('svg#svg defs #linearGradients'); - let linearGradient = gradientDefs.select('linearGradient#' + escapedId); - // If the linear gradient is not there yet, create it. - if (linearGradient.size() === 0) { - linearGradient = gradientDefs.append('linearGradient').attr('id', id); - // Re-create the stops of the linear gradient. - linearGradient.selectAll('*').remove(); - let cumulativeProportion = 0; - // For each device, create a stop using the proportion of that device. - _.each(renderInfo.deviceColors, d => { - let color = d.color; - linearGradient.append('stop') - .attr('offset', cumulativeProportion) - .attr('stop-color', color); - linearGradient.append('stop') - .attr('offset', cumulativeProportion + d.proportion) - .attr('stop-color', color); - cumulativeProportion += d.proportion; - }); - } - return isExpanded ? colorParams.EXPANDED_COLOR : `url(#${escapedId})`; + return isExpanded ? + colorParams.EXPANDED_COLOR : + getGradient( + 'device-' + renderInfo.node.name, renderInfo.deviceColors); + case ColorBy.XLA_CLUSTER: - return isExpanded ? colorParams.EXPANDED_COLOR : - renderInfo.xlaClusterColor || colorParams.UNKNOWN; + if (renderInfo.xlaClusterColors == null) { + // Return the hue for unknown xlaCluster. + return colorParams.UNKNOWN; + } + return isExpanded ? + colorParams.EXPANDED_COLOR : + getGradient( + 'xla-' + renderInfo.node.name, renderInfo.xlaClusterColors); + case ColorBy.COMPUTE_TIME: return isExpanded ? colorParams.EXPANDED_COLOR : renderInfo.computeTimeColor || diff --git a/tensorboard/plugins/graph/tf_graph_common/render.ts b/tensorboard/plugins/graph/tf_graph_common/render.ts index 3bca699ef3..f41674e18c 100644 --- a/tensorboard/plugins/graph/tf_graph_common/render.ts +++ b/tensorboard/plugins/graph/tf_graph_common/render.ts @@ -293,6 +293,24 @@ export class RenderGraphInfo { return this.hierarchy.node(nodeName); } + private colorHistogram( + histogram: {[name: string]: number}, + colors: d3.ScaleOrdinal): + Array<{color: string, proportion: number}> { + let pairs = _.pairs(histogram); + if (pairs.length > 0) { + // Compute the total # of items. + let numItems = _.sum(pairs, _.last); + return _.map(pairs, pair => ({ + color: colors(pair[0]), + // Normalize to a proportion of total # of items. + proportion: pair[1] / numItems + })); + } + console.info('no pairs found!'); + return null; + } + /** * Get a previously created RenderNodeInfo for the specified node name, * or create one if it hasn't been created yet. @@ -326,39 +344,33 @@ export class RenderGraphInfo { this.computeTimeScale(node.stats.getTotalMicros()); } - if (!node.isGroupNode) { - let clusterName = (node as OpNode).xlaCluster; - if (clusterName) { - renderInfo.xlaClusterColor = this.xlaClusterColorMap(clusterName); - } - } - // We only fade nodes when we're displaying stats. renderInfo.isFadedOut = this.displayingStats && !tf.graph.util.hasDisplayableNodeStats(node.stats); + var deviceHistogram = null; + var xlaClusterHistogram = null; if (node.isGroupNode) { - // Make a list of tuples (device, proportion), where proportion - // is the fraction of op nodes that have that device. - let pairs = _.pairs((node).deviceHistogram); - if (pairs.length > 0) { - // Compute the total # of devices. - let numDevices = _.sum(pairs, _.last); - renderInfo.deviceColors = _.map(pairs, pair => ({ - color: this.deviceColorMap(pair[0]), - // Normalize to a proportion of total # of devices. - proportion: pair[1] / numDevices - })); - } + deviceHistogram = (node).deviceHistogram; + xlaClusterHistogram = (node).xlaClusterHistogram; } else { let device = (renderInfo.node).device; if (device) { - renderInfo.deviceColors = [{ - color: this.deviceColorMap(device), - proportion: 1.0 - }]; + deviceHistogram = {[device]: 1}; + } + let xlaCluster = (renderInfo.node).xlaCluster; + if (xlaCluster) { + xlaClusterHistogram = {[xlaCluster]: 1}; } } + if (deviceHistogram) { + renderInfo.deviceColors = + this.colorHistogram(deviceHistogram, this.deviceColorMap); + } + if (xlaClusterHistogram) { + renderInfo.xlaClusterColors = + this.colorHistogram(xlaClusterHistogram, this.xlaClusterColorMap); + } return this.index[nodeName]; } @@ -578,6 +590,8 @@ export class RenderGraphInfo { newMetanode.templateId = libraryMetanode.templateId; newMetanode.opHistogram = _.clone(libraryMetanode.opHistogram); newMetanode.deviceHistogram = _.clone(libraryMetanode.deviceHistogram); + newMetanode.xlaClusterHistogram = + _.clone(libraryMetanode.xlaClusterHistogram); newMetanode.hasNonControlEdges = libraryMetanode.hasNonControlEdges; newMetanode.include = libraryMetanode.include; newMetanode.nodeAttributes = _.clone(libraryMetanode.nodeAttributes); @@ -704,7 +718,7 @@ export class RenderGraphInfo { // metagraph. const originalMetaEdges = this.hierarchy.getPredecessors( opNodeToReplace.name); - + // Find the metaedge that the input index corresponds to. // A metaedge may correspond to several edges. For instance, // an edge may enter a series node. @@ -1568,9 +1582,11 @@ export class RenderNodeInfo { deviceColors: Array<{color: string, proportion: number}>; /** - * Color according to the XLA cluster of this node. + * List of (color, proportion) tuples based on the proportion of xlaClusters + * of its children. If this node is an op node, this list will have only one + * color with proportion 1.0. */ - xlaClusterColor: string; + xlaClusterColors: Array<{color: string, proportion: number}>; /** * Color according to the memory usage of this node. diff --git a/tensorboard/plugins/graph/tf_graph_controls/tf-graph-controls.html b/tensorboard/plugins/graph/tf_graph_controls/tf-graph-controls.html index dcb86e6e8d..5c434f37e3 100644 --- a/tensorboard/plugins/graph/tf_graph_controls/tf-graph-controls.html +++ b/tensorboard/plugins/graph/tf_graph_controls/tf-graph-controls.html @@ -487,22 +487,16 @@