diff --git a/tensorboard/plugins/graph/tf_graph_app/demo/data/graph.pbtxt b/tensorboard/plugins/graph/tf_graph_app/demo/data/graph.pbtxt index 8b95b258df..4885a8ace7 100644 --- a/tensorboard/plugins/graph/tf_graph_app/demo/data/graph.pbtxt +++ b/tensorboard/plugins/graph/tf_graph_app/demo/data/graph.pbtxt @@ -62,8 +62,9 @@ node { } } node { - name: "Add" + name: "0/0/Add" op: "Add" + device: "1" input: "life" input: "universe" attr { @@ -72,11 +73,78 @@ node { type: DT_INT32 } } + attr { + key: "_XlaCluster" + value { + s: "cluster_0" + } + } +} +node { + name: "0/1/Add" + op: "Add" + device: "2" + input: "life" + input: "universe" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } +} +node { + name: "1/Add" + op: "Add" + device: "tpu" + input: "life" + input: "universe" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_2" + } + } +} +node { + name: "1/Weird" + op: "Weird" + device: "tpu" + input: "life" + input: "universe" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_3" + } + } } node { name: "answer" op: "Add" - input: "Add" + device: "2" + input: "0/0/Add" + input: "0/1/Add" + input: "1/Add" + input: "1/Weird" input: "everything" attr { key: "T" @@ -84,6 +152,12 @@ node { type: DT_INT32 } } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } } versions { producer: 10 diff --git a/tensorboard/plugins/graph/tf_graph_app/demo/index.html b/tensorboard/plugins/graph/tf_graph_app/demo/index.html index f71feea390..bef8527f0a 100644 --- a/tensorboard/plugins/graph/tf_graph_app/demo/index.html +++ b/tensorboard/plugins/graph/tf_graph_app/demo/index.html @@ -23,7 +23,7 @@ /** Make the graph app tall enough so the bottom legend does not overlap with the top. */ tf-graph-app, .container.tf-graph-app { display: block; - height: 700px; + height: 900px; }

Answer to the Ultimate Question of Life, the Universe, and Everything

diff --git a/tensorboard/plugins/graph/tf_graph_app/tf-graph-app.html b/tensorboard/plugins/graph/tf_graph_app/tf-graph-app.html index 1be0e3a983..b02a01d63a 100644 --- a/tensorboard/plugins/graph/tf_graph_app/tf-graph-app.html +++ b/tensorboard/plugins/graph/tf_graph_app/tf-graph-app.html @@ -90,12 +90,14 @@ color-by="{{colorBy}}" render-hierarchy="[[_renderHierarchy]]" selected-node="{{selectedNode}}" + selected-file="{{selectedFile}}" >
diff --git a/tensorboard/plugins/graph/tf_graph_common/graph.ts b/tensorboard/plugins/graph/tf_graph_common/graph.ts index 16f7901e21..f715682f74 100644 --- a/tensorboard/plugins/graph/tf_graph_common/graph.ts +++ b/tensorboard/plugins/graph/tf_graph_common/graph.ts @@ -295,6 +295,12 @@ export interface GroupNode extends Node { */ deviceHistogram: {[device: string]: number}; + /** + * Stores how many times each XLA cluster name appears in its children + * op nodes. Used to color group nodes by XLA clusters. + */ + xlaClusterHistogram: {[device: string]: number}; + /** * Stores how many ops in sub-graph were compatible and how many are * incompatible. @@ -595,6 +601,7 @@ export class MetanodeImpl implements Metanode { templateId: string; opHistogram: {[op: string]: number}; deviceHistogram: {[op: string]: number}; + xlaClusterHistogram: {[op: string]: number}; compatibilityHistogram: {compatible: number, incompatible: number}; parentNode: Node; hasNonControlEdges: boolean; @@ -624,6 +631,7 @@ export class MetanodeImpl implements Metanode { */ this.opHistogram = {}; this.deviceHistogram = {}; + this.xlaClusterHistogram = {}; this.compatibilityHistogram = {compatible: 0, incompatible: 0}; /** unique id for a metanode of similar subgraph */ this.templateId = null; @@ -831,6 +839,7 @@ class SeriesNodeImpl implements SeriesNode { bridgegraph: graphlib.Graph; parentNode: Node; deviceHistogram: {[op: string]: number}; + xlaClusterHistogram: {[op: string]: number}; compatibilityHistogram: {compatible: number, incompatible: number}; hasNonControlEdges: boolean; include: InclusionType; @@ -859,6 +868,7 @@ class SeriesNodeImpl implements SeriesNode { this.bridgegraph = null; this.parentNode = null; this.deviceHistogram = {}; + this.xlaClusterHistogram = {}; this.compatibilityHistogram = {compatible: 0, incompatible: 0}; this.hasNonControlEdges = false; this.include = InclusionType.UNSPECIFIED; diff --git a/tensorboard/plugins/graph/tf_graph_common/hierarchy.ts b/tensorboard/plugins/graph/tf_graph_common/hierarchy.ts index c4299df5ee..51d0dd096c 100644 --- a/tensorboard/plugins/graph/tf_graph_common/hierarchy.ts +++ b/tensorboard/plugins/graph/tf_graph_common/hierarchy.ts @@ -87,6 +87,7 @@ class HierarchyImpl implements Hierarchy { this.libraryFunctions = {}; this.templates = null; this.devices = null; + this.xlaClusters = null; /** * @type {Object} Dictionary object that maps node name to the node * (could be op-node, metanode, or series-node) @@ -475,15 +476,20 @@ export function build(graph: tf.graph.SlimGraph, params: HierarchyParams, export function joinAndAggregateStats( h: Hierarchy, stats: tf.graph.proto.StepStats) { - // Get all the possible device names. + // Get all the possible device and XLA cluster names. let deviceNames = {}; + let xlaClusterNames = {}; _.each(h.root.leaves(), nodeName => { let leaf = h.node(nodeName); if (leaf.device != null) { deviceNames[leaf.device] = true; } + if (leaf.xlaCluster != null) { + xlaClusterNames[leaf.xlaCluster] = true; + } }); h.devices = _.keys(deviceNames); + h.xlaClusters = _.keys(xlaClusterNames); // Reset stats for each group node. _.each(h.getNodeMap(), (node, nodeName) => { @@ -502,6 +508,12 @@ export function joinAndAggregateStats( let deviceHistogram = (node.parentNode).deviceHistogram; deviceHistogram[leaf.device] = (deviceHistogram[leaf.device] || 0) + 1; } + if (leaf.xlaCluster != null) { + let xlaClusterHistogram = + (node.parentNode).xlaClusterHistogram; + xlaClusterHistogram[leaf.xlaCluster] = + (xlaClusterHistogram[leaf.xlaCluster] || 0) + 1; + } if (leaf.stats != null) { node.parentNode.stats.combine(leaf.stats); } @@ -592,6 +604,10 @@ function addNodes(h: Hierarchy, graph: SlimGraph) { parent.deviceHistogram[node.device] = (parent.deviceHistogram[node.device] || 0) + 1; } + if (node.xlaCluster != null) { + parent.xlaClusterHistogram[node.xlaCluster] = + (parent.xlaClusterHistogram[node.xlaCluster] || 0) + 1; + } // Increment parents appropriate compatibility count if (node.compatible) { @@ -827,6 +843,10 @@ function groupSeries(metanode: Metanode, hierarchy: Hierarchy, seriesNode.deviceHistogram[child.device] = (seriesNode.deviceHistogram[child.device] || 0) + 1; } + if (child.xlaCluster != null) { + seriesNode.xlaClusterHistogram[child.xlaCluster] = + (seriesNode.xlaClusterHistogram[child.xlaCluster] || 0) + 1; + } // Increment parents appropriate compatibility count if (child.compatible) { diff --git a/tensorboard/plugins/graph/tf_graph_common/node.ts b/tensorboard/plugins/graph/tf_graph_common/node.ts index a8f0b10e10..69bf54ea9f 100644 --- a/tensorboard/plugins/graph/tf_graph_common/node.ts +++ b/tensorboard/plugins/graph/tf_graph_common/node.ts @@ -612,6 +612,32 @@ function position(nodeGroup, d: render.RenderNodeInfo) { export enum ColorBy {STRUCTURE, DEVICE, XLA_CLUSTER, COMPUTE_TIME, MEMORY, OP_COMPATIBILITY}; +function getGradient( + id: string, colors: Array<{color: string, proportion: number}>) { + let escapedId = tf.graph.util.escapeQuerySelector(id); + let gradientDefs = d3.select('svg#svg defs #linearGradients'); + let linearGradient = gradientDefs.select('linearGradient#' + escapedId); + // If the linear gradient is not there yet, create it. + if (linearGradient.size() === 0) { + linearGradient = gradientDefs.append('linearGradient').attr('id', id); + // Re-create the stops of the linear gradient. + linearGradient.selectAll('*').remove(); + let cumulativeProportion = 0; + // For each color, create a stop using the proportion of that device. + _.each(colors, d => { + let color = d.color; + linearGradient.append('stop') + .attr('offset', cumulativeProportion) + .attr('stop-color', color); + linearGradient.append('stop') + .attr('offset', cumulativeProportion + d.proportion) + .attr('stop-color', color); + cumulativeProportion += d.proportion; + }); + } + return `url(#${escapedId})`; +} + /** * Returns the fill color for the node given its state and the 'color by' * option. @@ -650,32 +676,21 @@ export function getFillForNode(templateIndex, colorBy, // Return the hue for unknown device. return colorParams.UNKNOWN; } - let id = renderInfo.node.name; - let escapedId = tf.graph.util.escapeQuerySelector(id); - let gradientDefs = d3.select('svg#svg defs #linearGradients'); - let linearGradient = gradientDefs.select('linearGradient#' + escapedId); - // If the linear gradient is not there yet, create it. - if (linearGradient.size() === 0) { - linearGradient = gradientDefs.append('linearGradient').attr('id', id); - // Re-create the stops of the linear gradient. - linearGradient.selectAll('*').remove(); - let cumulativeProportion = 0; - // For each device, create a stop using the proportion of that device. - _.each(renderInfo.deviceColors, d => { - let color = d.color; - linearGradient.append('stop') - .attr('offset', cumulativeProportion) - .attr('stop-color', color); - linearGradient.append('stop') - .attr('offset', cumulativeProportion + d.proportion) - .attr('stop-color', color); - cumulativeProportion += d.proportion; - }); - } - return isExpanded ? colorParams.EXPANDED_COLOR : `url(#${escapedId})`; + return isExpanded ? + colorParams.EXPANDED_COLOR : + getGradient( + 'device-' + renderInfo.node.name, renderInfo.deviceColors); + case ColorBy.XLA_CLUSTER: - return isExpanded ? colorParams.EXPANDED_COLOR : - renderInfo.xlaClusterColor || colorParams.UNKNOWN; + if (renderInfo.xlaClusterColors == null) { + // Return the hue for unknown xlaCluster. + return colorParams.UNKNOWN; + } + return isExpanded ? + colorParams.EXPANDED_COLOR : + getGradient( + 'xla-' + renderInfo.node.name, renderInfo.xlaClusterColors); + case ColorBy.COMPUTE_TIME: return isExpanded ? colorParams.EXPANDED_COLOR : renderInfo.computeTimeColor || @@ -685,52 +700,14 @@ export function getFillForNode(templateIndex, colorBy, colorParams.EXPANDED_COLOR : renderInfo.memoryColor || colorParams.UNKNOWN; case ColorBy.OP_COMPATIBILITY: - if (renderInfo.node.type === NodeType.OP) { - return ((renderInfo.node).compatible) ? - tf.graph.render.OpNodeColors.COMPATIBLE : - tf.graph.render.OpNodeColors.INCOMPATIBLE; - } else if (renderInfo.node.isGroupNode) { - let node = renderInfo.node; - - let numCompat = node.compatibilityHistogram.compatible; - let numIncompat = node.compatibilityHistogram.incompatible - - if (numCompat == 0 && numIncompat == 0) { - // Return the hue for unknown device. - return colorParams.UNKNOWN; - } - - let id = "op-compat-" + node.name; - let escapedId = tf.graph.util.escapeQuerySelector(id); - let gradientDefs = d3.select('svg#svg defs #linearGradients'); - let linearGradient = gradientDefs.select('linearGradient#' + escapedId); - // If the linear gradient is not there yet, create it. - if (linearGradient.size() === 0) { - let percentValid = numCompat / (numCompat + numIncompat); - linearGradient = gradientDefs.append('linearGradient').attr('id', id); - - // Re-create the stops of the linear gradient. - linearGradient.selectAll('*').remove(); - - linearGradient.append('stop') - .attr('offset', 0) - .attr('stop-color', tf.graph.render.OpNodeColors.COMPATIBLE); - linearGradient.append('stop') - .attr('offset', percentValid) - .attr('stop-color', tf.graph.render.OpNodeColors.COMPATIBLE); - linearGradient.append('stop') - .attr('offset', percentValid) - .attr('stop-color', tf.graph.render.OpNodeColors.INCOMPATIBLE); - linearGradient.append('stop') - .attr('offset', 1) - .attr('stop-color', tf.graph.render.OpNodeColors.INCOMPATIBLE); - } - - return isExpanded ? colorParams.EXPANDED_COLOR : `url(#${escapedId})`; - } else { - // All other nodes will be set to the default color - return colorParams.DEFAULT_FILL; + if (renderInfo.compatibilityColors == null) { + // Return the hue for unknown compatibility info. + return colorParams.UNKNOWN; } + return isExpanded ? colorParams.EXPANDED_COLOR : + getGradient( + 'op-compat-' + renderInfo.node.name, + renderInfo.compatibilityColors); default: throw new Error('Unknown case to color nodes by'); } diff --git a/tensorboard/plugins/graph/tf_graph_common/render.ts b/tensorboard/plugins/graph/tf_graph_common/render.ts index 3bca699ef3..af45729b2b 100644 --- a/tensorboard/plugins/graph/tf_graph_common/render.ts +++ b/tensorboard/plugins/graph/tf_graph_common/render.ts @@ -293,6 +293,24 @@ export class RenderGraphInfo { return this.hierarchy.node(nodeName); } + private colorHistogram( + histogram: {[name: string]: number}, + colors: d3.ScaleOrdinal): + Array<{color: string, proportion: number}> { + let pairs = _.pairs(histogram); + if (pairs.length > 0) { + // Compute the total # of items. + let numItems = _.sum(pairs, _.last); + return _.map(pairs, pair => ({ + color: colors(pair[0]), + // Normalize to a proportion of total # of items. + proportion: pair[1] / numItems + })); + } + console.info('no pairs found!'); + return null; + } + /** * Get a previously created RenderNodeInfo for the specified node name, * or create one if it hasn't been created yet. @@ -326,38 +344,53 @@ export class RenderGraphInfo { this.computeTimeScale(node.stats.getTotalMicros()); } - if (!node.isGroupNode) { - let clusterName = (node as OpNode).xlaCluster; - if (clusterName) { - renderInfo.xlaClusterColor = this.xlaClusterColorMap(clusterName); - } - } - // We only fade nodes when we're displaying stats. renderInfo.isFadedOut = this.displayingStats && !tf.graph.util.hasDisplayableNodeStats(node.stats); + var deviceHistogram = null; + var xlaClusterHistogram = null; + var opCompatibility = null; if (node.isGroupNode) { - // Make a list of tuples (device, proportion), where proportion - // is the fraction of op nodes that have that device. - let pairs = _.pairs((node).deviceHistogram); - if (pairs.length > 0) { - // Compute the total # of devices. - let numDevices = _.sum(pairs, _.last); - renderInfo.deviceColors = _.map(pairs, pair => ({ - color: this.deviceColorMap(pair[0]), - // Normalize to a proportion of total # of devices. - proportion: pair[1] / numDevices - })); + deviceHistogram = (node).deviceHistogram; + xlaClusterHistogram = (node).xlaClusterHistogram; + let compat = (node).compatibilityHistogram.compatible; + let incompat = (node).compatibilityHistogram.incompatible; + if (compat != 0 || incompat != 0) { + opCompatibility = compat / (compat + incompat); } } else { let device = (renderInfo.node).device; if (device) { - renderInfo.deviceColors = [{ - color: this.deviceColorMap(device), - proportion: 1.0 - }]; + deviceHistogram = {[device]: 1}; + } + let xlaCluster = (renderInfo.node).xlaCluster; + if (xlaCluster) { + xlaClusterHistogram = {[xlaCluster]: 1}; } + if (renderInfo.node.type === NodeType.OP) { + opCompatibility = (renderInfo.node).compatible ? 1 : 0; + } + } + if (deviceHistogram) { + renderInfo.deviceColors = + this.colorHistogram(deviceHistogram, this.deviceColorMap); + } + if (xlaClusterHistogram) { + renderInfo.xlaClusterColors = + this.colorHistogram(xlaClusterHistogram, this.xlaClusterColorMap); + } + if (opCompatibility != null) { + renderInfo.compatibilityColors = [ + { + color: tf.graph.render.OpNodeColors.COMPATIBLE, + proportion: opCompatibility + }, + { + color: tf.graph.render.OpNodeColors.INCOMPATIBLE, + proportion: 1 - opCompatibility + } + ]; } return this.index[nodeName]; @@ -578,6 +611,8 @@ export class RenderGraphInfo { newMetanode.templateId = libraryMetanode.templateId; newMetanode.opHistogram = _.clone(libraryMetanode.opHistogram); newMetanode.deviceHistogram = _.clone(libraryMetanode.deviceHistogram); + newMetanode.xlaClusterHistogram = + _.clone(libraryMetanode.xlaClusterHistogram); newMetanode.hasNonControlEdges = libraryMetanode.hasNonControlEdges; newMetanode.include = libraryMetanode.include; newMetanode.nodeAttributes = _.clone(libraryMetanode.nodeAttributes); @@ -704,7 +739,7 @@ export class RenderGraphInfo { // metagraph. const originalMetaEdges = this.hierarchy.getPredecessors( opNodeToReplace.name); - + // Find the metaedge that the input index corresponds to. // A metaedge may correspond to several edges. For instance, // an edge may enter a series node. @@ -1568,9 +1603,18 @@ export class RenderNodeInfo { deviceColors: Array<{color: string, proportion: number}>; /** - * Color according to the XLA cluster of this node. + * List of (color, proportion) tuples based on the proportion of xlaClusters + * of its children. If this node is an op node, this list will have only one + * color with proportion 1.0. + */ + xlaClusterColors: Array<{color: string, proportion: number}>; + + /** + * List of (color, proportion) tuples based on the proportion of compatible + * nodes of its children. If this node is an op node, this list will have only + * one color with proportion 1.0. */ - xlaClusterColor: string; + compatibilityColors: Array<{color: string, proportion: number}>; /** * Color according to the memory usage of this node. diff --git a/tensorboard/plugins/graph/tf_graph_controls/tf-graph-controls.html b/tensorboard/plugins/graph/tf_graph_controls/tf-graph-controls.html index dcb86e6e8d..5c434f37e3 100644 --- a/tensorboard/plugins/graph/tf_graph_controls/tf-graph-controls.html +++ b/tensorboard/plugins/graph/tf_graph_controls/tf-graph-controls.html @@ -487,22 +487,16 @@