diff --git a/tensorboard/plugins/graph/tf_graph_app/demo/data/graph.pbtxt b/tensorboard/plugins/graph/tf_graph_app/demo/data/graph.pbtxt
index 8b95b258df..4885a8ace7 100644
--- a/tensorboard/plugins/graph/tf_graph_app/demo/data/graph.pbtxt
+++ b/tensorboard/plugins/graph/tf_graph_app/demo/data/graph.pbtxt
@@ -62,8 +62,9 @@ node {
}
}
node {
- name: "Add"
+ name: "0/0/Add"
op: "Add"
+ device: "1"
input: "life"
input: "universe"
attr {
@@ -72,11 +73,78 @@ node {
type: DT_INT32
}
}
+ attr {
+ key: "_XlaCluster"
+ value {
+ s: "cluster_0"
+ }
+ }
+}
+node {
+ name: "0/1/Add"
+ op: "Add"
+ device: "2"
+ input: "life"
+ input: "universe"
+ attr {
+ key: "T"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "_XlaCluster"
+ value {
+ s: "cluster_1"
+ }
+ }
+}
+node {
+ name: "1/Add"
+ op: "Add"
+ device: "tpu"
+ input: "life"
+ input: "universe"
+ attr {
+ key: "T"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "_XlaCluster"
+ value {
+ s: "cluster_2"
+ }
+ }
+}
+node {
+ name: "1/Weird"
+ op: "Weird"
+ device: "tpu"
+ input: "life"
+ input: "universe"
+ attr {
+ key: "T"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "_XlaCluster"
+ value {
+ s: "cluster_3"
+ }
+ }
}
node {
name: "answer"
op: "Add"
- input: "Add"
+ device: "2"
+ input: "0/0/Add"
+ input: "0/1/Add"
+ input: "1/Add"
+ input: "1/Weird"
input: "everything"
attr {
key: "T"
@@ -84,6 +152,12 @@ node {
type: DT_INT32
}
}
+ attr {
+ key: "_XlaCluster"
+ value {
+ s: "cluster_1"
+ }
+ }
}
versions {
producer: 10
diff --git a/tensorboard/plugins/graph/tf_graph_app/demo/index.html b/tensorboard/plugins/graph/tf_graph_app/demo/index.html
index f71feea390..bef8527f0a 100644
--- a/tensorboard/plugins/graph/tf_graph_app/demo/index.html
+++ b/tensorboard/plugins/graph/tf_graph_app/demo/index.html
@@ -23,7 +23,7 @@
/** Make the graph app tall enough so the bottom legend does not overlap with the top. */
tf-graph-app, .container.tf-graph-app {
display: block;
- height: 700px;
+ height: 900px;
}
Answer to the Ultimate Question of Life, the Universe, and Everything
diff --git a/tensorboard/plugins/graph/tf_graph_app/tf-graph-app.html b/tensorboard/plugins/graph/tf_graph_app/tf-graph-app.html
index 1be0e3a983..b02a01d63a 100644
--- a/tensorboard/plugins/graph/tf_graph_app/tf-graph-app.html
+++ b/tensorboard/plugins/graph/tf_graph_app/tf-graph-app.html
@@ -90,12 +90,14 @@
color-by="{{colorBy}}"
render-hierarchy="[[_renderHierarchy]]"
selected-node="{{selectedNode}}"
+ selected-file="{{selectedFile}}"
>
diff --git a/tensorboard/plugins/graph/tf_graph_common/graph.ts b/tensorboard/plugins/graph/tf_graph_common/graph.ts
index 16f7901e21..f715682f74 100644
--- a/tensorboard/plugins/graph/tf_graph_common/graph.ts
+++ b/tensorboard/plugins/graph/tf_graph_common/graph.ts
@@ -295,6 +295,12 @@ export interface GroupNode extends Node {
*/
deviceHistogram: {[device: string]: number};
+ /**
+ * Stores how many times each XLA cluster name appears in its children
+ * op nodes. Used to color group nodes by XLA clusters.
+ */
+ xlaClusterHistogram: {[device: string]: number};
+
/**
* Stores how many ops in sub-graph were compatible and how many are
* incompatible.
@@ -595,6 +601,7 @@ export class MetanodeImpl implements Metanode {
templateId: string;
opHistogram: {[op: string]: number};
deviceHistogram: {[op: string]: number};
+ xlaClusterHistogram: {[op: string]: number};
compatibilityHistogram: {compatible: number, incompatible: number};
parentNode: Node;
hasNonControlEdges: boolean;
@@ -624,6 +631,7 @@ export class MetanodeImpl implements Metanode {
*/
this.opHistogram = {};
this.deviceHistogram = {};
+ this.xlaClusterHistogram = {};
this.compatibilityHistogram = {compatible: 0, incompatible: 0};
/** unique id for a metanode of similar subgraph */
this.templateId = null;
@@ -831,6 +839,7 @@ class SeriesNodeImpl implements SeriesNode {
bridgegraph: graphlib.Graph;
parentNode: Node;
deviceHistogram: {[op: string]: number};
+ xlaClusterHistogram: {[op: string]: number};
compatibilityHistogram: {compatible: number, incompatible: number};
hasNonControlEdges: boolean;
include: InclusionType;
@@ -859,6 +868,7 @@ class SeriesNodeImpl implements SeriesNode {
this.bridgegraph = null;
this.parentNode = null;
this.deviceHistogram = {};
+ this.xlaClusterHistogram = {};
this.compatibilityHistogram = {compatible: 0, incompatible: 0};
this.hasNonControlEdges = false;
this.include = InclusionType.UNSPECIFIED;
diff --git a/tensorboard/plugins/graph/tf_graph_common/hierarchy.ts b/tensorboard/plugins/graph/tf_graph_common/hierarchy.ts
index c4299df5ee..51d0dd096c 100644
--- a/tensorboard/plugins/graph/tf_graph_common/hierarchy.ts
+++ b/tensorboard/plugins/graph/tf_graph_common/hierarchy.ts
@@ -87,6 +87,7 @@ class HierarchyImpl implements Hierarchy {
this.libraryFunctions = {};
this.templates = null;
this.devices = null;
+ this.xlaClusters = null;
/**
* @type {Object} Dictionary object that maps node name to the node
* (could be op-node, metanode, or series-node)
@@ -475,15 +476,20 @@ export function build(graph: tf.graph.SlimGraph, params: HierarchyParams,
export function joinAndAggregateStats(
h: Hierarchy, stats: tf.graph.proto.StepStats) {
- // Get all the possible device names.
+ // Get all the possible device and XLA cluster names.
let deviceNames = {};
+ let xlaClusterNames = {};
_.each(h.root.leaves(), nodeName => {
let leaf = h.node(nodeName);
if (leaf.device != null) {
deviceNames[leaf.device] = true;
}
+ if (leaf.xlaCluster != null) {
+ xlaClusterNames[leaf.xlaCluster] = true;
+ }
});
h.devices = _.keys(deviceNames);
+ h.xlaClusters = _.keys(xlaClusterNames);
// Reset stats for each group node.
_.each(h.getNodeMap(), (node, nodeName) => {
@@ -502,6 +508,12 @@ export function joinAndAggregateStats(
let deviceHistogram = (node.parentNode).deviceHistogram;
deviceHistogram[leaf.device] = (deviceHistogram[leaf.device] || 0) + 1;
}
+ if (leaf.xlaCluster != null) {
+ let xlaClusterHistogram =
+ (node.parentNode).xlaClusterHistogram;
+ xlaClusterHistogram[leaf.xlaCluster] =
+ (xlaClusterHistogram[leaf.xlaCluster] || 0) + 1;
+ }
if (leaf.stats != null) {
node.parentNode.stats.combine(leaf.stats);
}
@@ -592,6 +604,10 @@ function addNodes(h: Hierarchy, graph: SlimGraph) {
parent.deviceHistogram[node.device] =
(parent.deviceHistogram[node.device] || 0) + 1;
}
+ if (node.xlaCluster != null) {
+ parent.xlaClusterHistogram[node.xlaCluster] =
+ (parent.xlaClusterHistogram[node.xlaCluster] || 0) + 1;
+ }
// Increment parents appropriate compatibility count
if (node.compatible) {
@@ -827,6 +843,10 @@ function groupSeries(metanode: Metanode, hierarchy: Hierarchy,
seriesNode.deviceHistogram[child.device] =
(seriesNode.deviceHistogram[child.device] || 0) + 1;
}
+ if (child.xlaCluster != null) {
+ seriesNode.xlaClusterHistogram[child.xlaCluster] =
+ (seriesNode.xlaClusterHistogram[child.xlaCluster] || 0) + 1;
+ }
// Increment parents appropriate compatibility count
if (child.compatible) {
diff --git a/tensorboard/plugins/graph/tf_graph_common/node.ts b/tensorboard/plugins/graph/tf_graph_common/node.ts
index a8f0b10e10..69bf54ea9f 100644
--- a/tensorboard/plugins/graph/tf_graph_common/node.ts
+++ b/tensorboard/plugins/graph/tf_graph_common/node.ts
@@ -612,6 +612,32 @@ function position(nodeGroup, d: render.RenderNodeInfo) {
export enum ColorBy {STRUCTURE, DEVICE, XLA_CLUSTER, COMPUTE_TIME, MEMORY,
OP_COMPATIBILITY};
+function getGradient(
+ id: string, colors: Array<{color: string, proportion: number}>) {
+ let escapedId = tf.graph.util.escapeQuerySelector(id);
+ let gradientDefs = d3.select('svg#svg defs #linearGradients');
+ let linearGradient = gradientDefs.select('linearGradient#' + escapedId);
+ // If the linear gradient is not there yet, create it.
+ if (linearGradient.size() === 0) {
+ linearGradient = gradientDefs.append('linearGradient').attr('id', id);
+ // Re-create the stops of the linear gradient.
+ linearGradient.selectAll('*').remove();
+ let cumulativeProportion = 0;
+ // For each color, create a stop using the proportion of that device.
+ _.each(colors, d => {
+ let color = d.color;
+ linearGradient.append('stop')
+ .attr('offset', cumulativeProportion)
+ .attr('stop-color', color);
+ linearGradient.append('stop')
+ .attr('offset', cumulativeProportion + d.proportion)
+ .attr('stop-color', color);
+ cumulativeProportion += d.proportion;
+ });
+ }
+ return `url(#${escapedId})`;
+}
+
/**
* Returns the fill color for the node given its state and the 'color by'
* option.
@@ -650,32 +676,21 @@ export function getFillForNode(templateIndex, colorBy,
// Return the hue for unknown device.
return colorParams.UNKNOWN;
}
- let id = renderInfo.node.name;
- let escapedId = tf.graph.util.escapeQuerySelector(id);
- let gradientDefs = d3.select('svg#svg defs #linearGradients');
- let linearGradient = gradientDefs.select('linearGradient#' + escapedId);
- // If the linear gradient is not there yet, create it.
- if (linearGradient.size() === 0) {
- linearGradient = gradientDefs.append('linearGradient').attr('id', id);
- // Re-create the stops of the linear gradient.
- linearGradient.selectAll('*').remove();
- let cumulativeProportion = 0;
- // For each device, create a stop using the proportion of that device.
- _.each(renderInfo.deviceColors, d => {
- let color = d.color;
- linearGradient.append('stop')
- .attr('offset', cumulativeProportion)
- .attr('stop-color', color);
- linearGradient.append('stop')
- .attr('offset', cumulativeProportion + d.proportion)
- .attr('stop-color', color);
- cumulativeProportion += d.proportion;
- });
- }
- return isExpanded ? colorParams.EXPANDED_COLOR : `url(#${escapedId})`;
+ return isExpanded ?
+ colorParams.EXPANDED_COLOR :
+ getGradient(
+ 'device-' + renderInfo.node.name, renderInfo.deviceColors);
+
case ColorBy.XLA_CLUSTER:
- return isExpanded ? colorParams.EXPANDED_COLOR :
- renderInfo.xlaClusterColor || colorParams.UNKNOWN;
+ if (renderInfo.xlaClusterColors == null) {
+ // Return the hue for unknown xlaCluster.
+ return colorParams.UNKNOWN;
+ }
+ return isExpanded ?
+ colorParams.EXPANDED_COLOR :
+ getGradient(
+ 'xla-' + renderInfo.node.name, renderInfo.xlaClusterColors);
+
case ColorBy.COMPUTE_TIME:
return isExpanded ?
colorParams.EXPANDED_COLOR : renderInfo.computeTimeColor ||
@@ -685,52 +700,14 @@ export function getFillForNode(templateIndex, colorBy,
colorParams.EXPANDED_COLOR : renderInfo.memoryColor ||
colorParams.UNKNOWN;
case ColorBy.OP_COMPATIBILITY:
- if (renderInfo.node.type === NodeType.OP) {
- return ((renderInfo.node).compatible) ?
- tf.graph.render.OpNodeColors.COMPATIBLE :
- tf.graph.render.OpNodeColors.INCOMPATIBLE;
- } else if (renderInfo.node.isGroupNode) {
- let node = renderInfo.node;
-
- let numCompat = node.compatibilityHistogram.compatible;
- let numIncompat = node.compatibilityHistogram.incompatible
-
- if (numCompat == 0 && numIncompat == 0) {
- // Return the hue for unknown device.
- return colorParams.UNKNOWN;
- }
-
- let id = "op-compat-" + node.name;
- let escapedId = tf.graph.util.escapeQuerySelector(id);
- let gradientDefs = d3.select('svg#svg defs #linearGradients');
- let linearGradient = gradientDefs.select('linearGradient#' + escapedId);
- // If the linear gradient is not there yet, create it.
- if (linearGradient.size() === 0) {
- let percentValid = numCompat / (numCompat + numIncompat);
- linearGradient = gradientDefs.append('linearGradient').attr('id', id);
-
- // Re-create the stops of the linear gradient.
- linearGradient.selectAll('*').remove();
-
- linearGradient.append('stop')
- .attr('offset', 0)
- .attr('stop-color', tf.graph.render.OpNodeColors.COMPATIBLE);
- linearGradient.append('stop')
- .attr('offset', percentValid)
- .attr('stop-color', tf.graph.render.OpNodeColors.COMPATIBLE);
- linearGradient.append('stop')
- .attr('offset', percentValid)
- .attr('stop-color', tf.graph.render.OpNodeColors.INCOMPATIBLE);
- linearGradient.append('stop')
- .attr('offset', 1)
- .attr('stop-color', tf.graph.render.OpNodeColors.INCOMPATIBLE);
- }
-
- return isExpanded ? colorParams.EXPANDED_COLOR : `url(#${escapedId})`;
- } else {
- // All other nodes will be set to the default color
- return colorParams.DEFAULT_FILL;
+ if (renderInfo.compatibilityColors == null) {
+ // Return the hue for unknown compatibility info.
+ return colorParams.UNKNOWN;
}
+ return isExpanded ? colorParams.EXPANDED_COLOR :
+ getGradient(
+ 'op-compat-' + renderInfo.node.name,
+ renderInfo.compatibilityColors);
default:
throw new Error('Unknown case to color nodes by');
}
diff --git a/tensorboard/plugins/graph/tf_graph_common/render.ts b/tensorboard/plugins/graph/tf_graph_common/render.ts
index 3bca699ef3..af45729b2b 100644
--- a/tensorboard/plugins/graph/tf_graph_common/render.ts
+++ b/tensorboard/plugins/graph/tf_graph_common/render.ts
@@ -293,6 +293,24 @@ export class RenderGraphInfo {
return this.hierarchy.node(nodeName);
}
+ private colorHistogram(
+ histogram: {[name: string]: number},
+ colors: d3.ScaleOrdinal):
+ Array<{color: string, proportion: number}> {
+ let pairs = _.pairs(histogram);
+ if (pairs.length > 0) {
+ // Compute the total # of items.
+ let numItems = _.sum(pairs, _.last);
+ return _.map(pairs, pair => ({
+ color: colors(pair[0]),
+ // Normalize to a proportion of total # of items.
+ proportion: pair[1] / numItems
+ }));
+ }
+ console.info('no pairs found!');
+ return null;
+ }
+
/**
* Get a previously created RenderNodeInfo for the specified node name,
* or create one if it hasn't been created yet.
@@ -326,38 +344,53 @@ export class RenderGraphInfo {
this.computeTimeScale(node.stats.getTotalMicros());
}
- if (!node.isGroupNode) {
- let clusterName = (node as OpNode).xlaCluster;
- if (clusterName) {
- renderInfo.xlaClusterColor = this.xlaClusterColorMap(clusterName);
- }
- }
-
// We only fade nodes when we're displaying stats.
renderInfo.isFadedOut = this.displayingStats &&
!tf.graph.util.hasDisplayableNodeStats(node.stats);
+ var deviceHistogram = null;
+ var xlaClusterHistogram = null;
+ var opCompatibility = null;
if (node.isGroupNode) {
- // Make a list of tuples (device, proportion), where proportion
- // is the fraction of op nodes that have that device.
- let pairs = _.pairs((node).deviceHistogram);
- if (pairs.length > 0) {
- // Compute the total # of devices.
- let numDevices = _.sum(pairs, _.last);
- renderInfo.deviceColors = _.map(pairs, pair => ({
- color: this.deviceColorMap(pair[0]),
- // Normalize to a proportion of total # of devices.
- proportion: pair[1] / numDevices
- }));
+ deviceHistogram = (node).deviceHistogram;
+ xlaClusterHistogram = (node).xlaClusterHistogram;
+ let compat = (node).compatibilityHistogram.compatible;
+ let incompat = (node).compatibilityHistogram.incompatible;
+ if (compat != 0 || incompat != 0) {
+ opCompatibility = compat / (compat + incompat);
}
} else {
let device = (renderInfo.node).device;
if (device) {
- renderInfo.deviceColors = [{
- color: this.deviceColorMap(device),
- proportion: 1.0
- }];
+ deviceHistogram = {[device]: 1};
+ }
+ let xlaCluster = (renderInfo.node).xlaCluster;
+ if (xlaCluster) {
+ xlaClusterHistogram = {[xlaCluster]: 1};
}
+ if (renderInfo.node.type === NodeType.OP) {
+ opCompatibility = (renderInfo.node).compatible ? 1 : 0;
+ }
+ }
+ if (deviceHistogram) {
+ renderInfo.deviceColors =
+ this.colorHistogram(deviceHistogram, this.deviceColorMap);
+ }
+ if (xlaClusterHistogram) {
+ renderInfo.xlaClusterColors =
+ this.colorHistogram(xlaClusterHistogram, this.xlaClusterColorMap);
+ }
+ if (opCompatibility != null) {
+ renderInfo.compatibilityColors = [
+ {
+ color: tf.graph.render.OpNodeColors.COMPATIBLE,
+ proportion: opCompatibility
+ },
+ {
+ color: tf.graph.render.OpNodeColors.INCOMPATIBLE,
+ proportion: 1 - opCompatibility
+ }
+ ];
}
return this.index[nodeName];
@@ -578,6 +611,8 @@ export class RenderGraphInfo {
newMetanode.templateId = libraryMetanode.templateId;
newMetanode.opHistogram = _.clone(libraryMetanode.opHistogram);
newMetanode.deviceHistogram = _.clone(libraryMetanode.deviceHistogram);
+ newMetanode.xlaClusterHistogram =
+ _.clone(libraryMetanode.xlaClusterHistogram);
newMetanode.hasNonControlEdges = libraryMetanode.hasNonControlEdges;
newMetanode.include = libraryMetanode.include;
newMetanode.nodeAttributes = _.clone(libraryMetanode.nodeAttributes);
@@ -704,7 +739,7 @@ export class RenderGraphInfo {
// metagraph.
const originalMetaEdges = this.hierarchy.getPredecessors(
opNodeToReplace.name);
-
+
// Find the metaedge that the input index corresponds to.
// A metaedge may correspond to several edges. For instance,
// an edge may enter a series node.
@@ -1568,9 +1603,18 @@ export class RenderNodeInfo {
deviceColors: Array<{color: string, proportion: number}>;
/**
- * Color according to the XLA cluster of this node.
+ * List of (color, proportion) tuples based on the proportion of xlaClusters
+ * of its children. If this node is an op node, this list will have only one
+ * color with proportion 1.0.
+ */
+ xlaClusterColors: Array<{color: string, proportion: number}>;
+
+ /**
+ * List of (color, proportion) tuples based on the proportion of compatible
+ * nodes of its children. If this node is an op node, this list will have only
+ * one color with proportion 1.0.
*/
- xlaClusterColor: string;
+ compatibilityColors: Array<{color: string, proportion: number}>;
/**
* Color according to the memory usage of this node.
diff --git a/tensorboard/plugins/graph/tf_graph_controls/tf-graph-controls.html b/tensorboard/plugins/graph/tf_graph_controls/tf-graph-controls.html
index dcb86e6e8d..5c434f37e3 100644
--- a/tensorboard/plugins/graph/tf_graph_controls/tf-graph-controls.html
+++ b/tensorboard/plugins/graph/tf_graph_controls/tf-graph-controls.html
@@ -487,22 +487,16 @@
-
-
-
-
-
- |
-
- |
-
- [[item.xla_cluster]]
- |
-
-
-
-
-
+
+
+
+
+ [[item.xla_cluster]]
+
+