Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 76 additions & 2 deletions tensorboard/plugins/graph/tf_graph_app/demo/data/graph.pbtxt
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,9 @@ node {
}
}
node {
name: "Add"
name: "0/0/Add"
op: "Add"
device: "1"
input: "life"
input: "universe"
attr {
Expand All @@ -72,18 +73,91 @@ node {
type: DT_INT32
}
}
attr {
key: "_XlaCluster"
value {
s: "cluster_0"
}
}
}
node {
name: "0/1/Add"
op: "Add"
device: "2"
input: "life"
input: "universe"
attr {
key: "T"
value {
type: DT_INT32
}
}
attr {
key: "_XlaCluster"
value {
s: "cluster_1"
}
}
}
node {
name: "1/Add"
op: "Add"
device: "tpu"
input: "life"
input: "universe"
attr {
key: "T"
value {
type: DT_INT32
}
}
attr {
key: "_XlaCluster"
value {
s: "cluster_2"
}
}
}
node {
name: "1/Weird"
op: "Weird"
device: "tpu"
input: "life"
input: "universe"
attr {
key: "T"
value {
type: DT_INT32
}
}
attr {
key: "_XlaCluster"
value {
s: "cluster_3"
}
}
}
node {
name: "answer"
op: "Add"
input: "Add"
device: "2"
input: "0/0/Add"
input: "0/1/Add"
input: "1/Add"
input: "1/Weird"
input: "everything"
attr {
key: "T"
value {
type: DT_INT32
}
}
attr {
key: "_XlaCluster"
value {
s: "cluster_1"
}
}
}
versions {
producer: 10
Expand Down
2 changes: 1 addition & 1 deletion tensorboard/plugins/graph/tf_graph_app/demo/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
/** Make the graph app tall enough so the bottom legend does not overlap with the top. */
tf-graph-app, .container.tf-graph-app {
display: block;
height: 700px;
height: 900px;
}
</style>
<h3>Answer to the Ultimate Question of Life, the Universe, and Everything</h3>
Expand Down
2 changes: 2 additions & 0 deletions tensorboard/plugins/graph/tf_graph_app/tf-graph-app.html
Original file line number Diff line number Diff line change
Expand Up @@ -90,12 +90,14 @@
color-by="{{colorBy}}"
render-hierarchy="[[_renderHierarchy]]"
selected-node="{{selectedNode}}"
selected-file="{{selectedFile}}"
></tf-graph-controls>
<tf-graph-loader id="loader"
out-graph-hierarchy="{{graphHierarchy}}"
out-graph="{{graph}}"
out-stats="{{stats}}"
progress="{{_progress}}"
selected-file="[[selectedFile]]"
></tf-graph-loader>
</div>
<div class="main">
Expand Down
10 changes: 10 additions & 0 deletions tensorboard/plugins/graph/tf_graph_common/graph.ts
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,12 @@ export interface GroupNode extends Node {
*/
deviceHistogram: {[device: string]: number};

/**
* Stores how many times each XLA cluster name appears in its children
* op nodes. Used to color group nodes by XLA clusters.
*/
xlaClusterHistogram: {[device: string]: number};

/**
* Stores how many ops in sub-graph were compatible and how many are
* incompatible.
Expand Down Expand Up @@ -595,6 +601,7 @@ export class MetanodeImpl implements Metanode {
templateId: string;
opHistogram: {[op: string]: number};
deviceHistogram: {[op: string]: number};
xlaClusterHistogram: {[op: string]: number};
compatibilityHistogram: {compatible: number, incompatible: number};
parentNode: Node;
hasNonControlEdges: boolean;
Expand Down Expand Up @@ -624,6 +631,7 @@ export class MetanodeImpl implements Metanode {
*/
this.opHistogram = {};
this.deviceHistogram = {};
this.xlaClusterHistogram = {};
this.compatibilityHistogram = {compatible: 0, incompatible: 0};
/** unique id for a metanode of similar subgraph */
this.templateId = null;
Expand Down Expand Up @@ -831,6 +839,7 @@ class SeriesNodeImpl implements SeriesNode {
bridgegraph: graphlib.Graph<GroupNode|OpNode, Metaedge>;
parentNode: Node;
deviceHistogram: {[op: string]: number};
xlaClusterHistogram: {[op: string]: number};
compatibilityHistogram: {compatible: number, incompatible: number};
hasNonControlEdges: boolean;
include: InclusionType;
Expand Down Expand Up @@ -859,6 +868,7 @@ class SeriesNodeImpl implements SeriesNode {
this.bridgegraph = null;
this.parentNode = null;
this.deviceHistogram = {};
this.xlaClusterHistogram = {};
this.compatibilityHistogram = {compatible: 0, incompatible: 0};
this.hasNonControlEdges = false;
this.include = InclusionType.UNSPECIFIED;
Expand Down
22 changes: 21 additions & 1 deletion tensorboard/plugins/graph/tf_graph_common/hierarchy.ts
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ class HierarchyImpl implements Hierarchy {
this.libraryFunctions = {};
this.templates = null;
this.devices = null;
this.xlaClusters = null;
/**
* @type {Object} Dictionary object that maps node name to the node
* (could be op-node, metanode, or series-node)
Expand Down Expand Up @@ -475,15 +476,20 @@ export function build(graph: tf.graph.SlimGraph, params: HierarchyParams,

export function joinAndAggregateStats(
h: Hierarchy, stats: tf.graph.proto.StepStats) {
// Get all the possible device names.
// Get all the possible device and XLA cluster names.
let deviceNames = {};
let xlaClusterNames = {};
_.each(h.root.leaves(), nodeName => {
let leaf = <OpNode> h.node(nodeName);
if (leaf.device != null) {
deviceNames[leaf.device] = true;
}
if (leaf.xlaCluster != null) {
xlaClusterNames[leaf.xlaCluster] = true;
}
});
h.devices = _.keys(deviceNames);
h.xlaClusters = _.keys(xlaClusterNames);

// Reset stats for each group node.
_.each(h.getNodeMap(), (node, nodeName) => {
Expand All @@ -502,6 +508,12 @@ export function joinAndAggregateStats(
let deviceHistogram = (<GroupNode>node.parentNode).deviceHistogram;
deviceHistogram[leaf.device] = (deviceHistogram[leaf.device] || 0) + 1;
}
if (leaf.xlaCluster != null) {
let xlaClusterHistogram =
(<GroupNode>node.parentNode).xlaClusterHistogram;
xlaClusterHistogram[leaf.xlaCluster] =
(xlaClusterHistogram[leaf.xlaCluster] || 0) + 1;
}
if (leaf.stats != null) {
node.parentNode.stats.combine(leaf.stats);
}
Expand Down Expand Up @@ -592,6 +604,10 @@ function addNodes(h: Hierarchy, graph: SlimGraph) {
parent.deviceHistogram[node.device] =
(parent.deviceHistogram[node.device] || 0) + 1;
}
if (node.xlaCluster != null) {
parent.xlaClusterHistogram[node.xlaCluster] =
(parent.xlaClusterHistogram[node.xlaCluster] || 0) + 1;
}

// Increment parents appropriate compatibility count
if (node.compatible) {
Expand Down Expand Up @@ -827,6 +843,10 @@ function groupSeries(metanode: Metanode, hierarchy: Hierarchy,
seriesNode.deviceHistogram[child.device] =
(seriesNode.deviceHistogram[child.device] || 0) + 1;
}
if (child.xlaCluster != null) {
seriesNode.xlaClusterHistogram[child.xlaCluster] =
(seriesNode.xlaClusterHistogram[child.xlaCluster] || 0) + 1;
}

// Increment parents appropriate compatibility count
if (child.compatible) {
Expand Down
117 changes: 47 additions & 70 deletions tensorboard/plugins/graph/tf_graph_common/node.ts
Original file line number Diff line number Diff line change
Expand Up @@ -612,6 +612,32 @@ function position(nodeGroup, d: render.RenderNodeInfo) {
export enum ColorBy {STRUCTURE, DEVICE, XLA_CLUSTER, COMPUTE_TIME, MEMORY,
OP_COMPATIBILITY};

function getGradient(
id: string, colors: Array<{color: string, proportion: number}>) {
let escapedId = tf.graph.util.escapeQuerySelector(id);
let gradientDefs = d3.select('svg#svg defs #linearGradients');
let linearGradient = gradientDefs.select('linearGradient#' + escapedId);
// If the linear gradient is not there yet, create it.
if (linearGradient.size() === 0) {
linearGradient = gradientDefs.append('linearGradient').attr('id', id);
// Re-create the stops of the linear gradient.
linearGradient.selectAll('*').remove();
let cumulativeProportion = 0;
// For each color, create a stop using the proportion of that device.
_.each(colors, d => {
let color = d.color;
linearGradient.append('stop')
.attr('offset', cumulativeProportion)
.attr('stop-color', color);
linearGradient.append('stop')
.attr('offset', cumulativeProportion + d.proportion)
.attr('stop-color', color);
cumulativeProportion += d.proportion;
});
}
return `url(#${escapedId})`;
}

/**
* Returns the fill color for the node given its state and the 'color by'
* option.
Expand Down Expand Up @@ -650,32 +676,21 @@ export function getFillForNode(templateIndex, colorBy,
// Return the hue for unknown device.
return colorParams.UNKNOWN;
}
let id = renderInfo.node.name;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It'd be nice to also deduplicate the code for the TPU compatibility logic below - it looks like it was essentially copy-pasted between the device color and TPU compat color cases previously, so now that you've factored it out it would ideally be factored out in all three places.

I think it should be pretty easy to adapt, but if you'd rather not bother that's fine and one of us can clean it up later.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

let escapedId = tf.graph.util.escapeQuerySelector(id);
let gradientDefs = d3.select('svg#svg defs #linearGradients');
let linearGradient = gradientDefs.select('linearGradient#' + escapedId);
// If the linear gradient is not there yet, create it.
if (linearGradient.size() === 0) {
linearGradient = gradientDefs.append('linearGradient').attr('id', id);
// Re-create the stops of the linear gradient.
linearGradient.selectAll('*').remove();
let cumulativeProportion = 0;
// For each device, create a stop using the proportion of that device.
_.each(renderInfo.deviceColors, d => {
let color = d.color;
linearGradient.append('stop')
.attr('offset', cumulativeProportion)
.attr('stop-color', color);
linearGradient.append('stop')
.attr('offset', cumulativeProportion + d.proportion)
.attr('stop-color', color);
cumulativeProportion += d.proportion;
});
}
return isExpanded ? colorParams.EXPANDED_COLOR : `url(#${escapedId})`;
return isExpanded ?
colorParams.EXPANDED_COLOR :
getGradient(
'device-' + renderInfo.node.name, renderInfo.deviceColors);

case ColorBy.XLA_CLUSTER:
return isExpanded ? colorParams.EXPANDED_COLOR :
renderInfo.xlaClusterColor || colorParams.UNKNOWN;
if (renderInfo.xlaClusterColors == null) {
// Return the hue for unknown xlaCluster.
return colorParams.UNKNOWN;
}
return isExpanded ?
colorParams.EXPANDED_COLOR :
getGradient(
'xla-' + renderInfo.node.name, renderInfo.xlaClusterColors);

case ColorBy.COMPUTE_TIME:
return isExpanded ?
colorParams.EXPANDED_COLOR : renderInfo.computeTimeColor ||
Expand All @@ -685,52 +700,14 @@ export function getFillForNode(templateIndex, colorBy,
colorParams.EXPANDED_COLOR : renderInfo.memoryColor ||
colorParams.UNKNOWN;
case ColorBy.OP_COMPATIBILITY:
if (renderInfo.node.type === NodeType.OP) {
return ((<OpNode>renderInfo.node).compatible) ?
tf.graph.render.OpNodeColors.COMPATIBLE :
tf.graph.render.OpNodeColors.INCOMPATIBLE;
} else if (renderInfo.node.isGroupNode) {
let node = <GroupNode>renderInfo.node;

let numCompat = node.compatibilityHistogram.compatible;
let numIncompat = node.compatibilityHistogram.incompatible

if (numCompat == 0 && numIncompat == 0) {
// Return the hue for unknown device.
return colorParams.UNKNOWN;
}

let id = "op-compat-" + node.name;
let escapedId = tf.graph.util.escapeQuerySelector(id);
let gradientDefs = d3.select('svg#svg defs #linearGradients');
let linearGradient = gradientDefs.select('linearGradient#' + escapedId);
// If the linear gradient is not there yet, create it.
if (linearGradient.size() === 0) {
let percentValid = numCompat / (numCompat + numIncompat);
linearGradient = gradientDefs.append('linearGradient').attr('id', id);

// Re-create the stops of the linear gradient.
linearGradient.selectAll('*').remove();

linearGradient.append('stop')
.attr('offset', 0)
.attr('stop-color', tf.graph.render.OpNodeColors.COMPATIBLE);
linearGradient.append('stop')
.attr('offset', percentValid)
.attr('stop-color', tf.graph.render.OpNodeColors.COMPATIBLE);
linearGradient.append('stop')
.attr('offset', percentValid)
.attr('stop-color', tf.graph.render.OpNodeColors.INCOMPATIBLE);
linearGradient.append('stop')
.attr('offset', 1)
.attr('stop-color', tf.graph.render.OpNodeColors.INCOMPATIBLE);
}

return isExpanded ? colorParams.EXPANDED_COLOR : `url(#${escapedId})`;
} else {
// All other nodes will be set to the default color
return colorParams.DEFAULT_FILL;
if (renderInfo.compatibilityColors == null) {
// Return the hue for unknown compatibility info.
return colorParams.UNKNOWN;
}
return isExpanded ? colorParams.EXPANDED_COLOR :
getGradient(
'op-compat-' + renderInfo.node.name,
renderInfo.compatibilityColors);
default:
throw new Error('Unknown case to color nodes by');
}
Expand Down
Loading