Skip to content

Commit

Permalink
add utitlity to visualize datapipe graphs (#330)
Browse files Browse the repository at this point in the history
Summary:
Closes #299.

Pull Request resolved: #330

Reviewed By: ejguan

Differential Revision: D36990978

Pulled By: NivekT

fbshipit-source-id: 423a52e6d59da59826d1422676d28287537dd902
  • Loading branch information
pmeier authored and NivekT committed Jun 9, 2022
1 parent 5d0675c commit ce085e6
Show file tree
Hide file tree
Showing 6 changed files with 184 additions and 1 deletion.
1 change: 1 addition & 0 deletions .github/workflows/lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ jobs:
set -eux
python -m pip install requests --user
python -m pip install mypy==0.812 --user
python -m pip install graphviz --user
- name: Build TorchData
run: |
set -eux
Expand Down
5 changes: 5 additions & 0 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,3 +142,8 @@ def setup(app):
obj.__name__ = name

app.connect("autodoc-process-signature", process_signature)


intersphinx_mapping = {
"graphviz": ("https://graphviz.readthedocs.io/en/stable/", None),
}
1 change: 1 addition & 0 deletions docs/source/torchdata.datapipes.utils.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ File Object and Stream Utility
:template: datapipe.rst

StreamWrapper
to_graph


DataLoader
Expand Down
3 changes: 3 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,6 @@ ignore_missing_imports = True

[mypy-pyarrow.*]
ignore_missing_imports = True

[mypy-graphviz.*]
ignore_missing_imports = True
4 changes: 3 additions & 1 deletion torchdata/datapipes/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,6 @@

from torch.utils.data.datapipes.utils.common import StreamWrapper

__all__ = ["StreamWrapper"]
from ._visualization import to_graph

__all__ = ["StreamWrapper", "to_graph"]
171 changes: 171 additions & 0 deletions torchdata/datapipes/utils/_visualization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
import itertools
from collections import defaultdict

from typing import Optional, Set, TYPE_CHECKING

from torch.utils.data.datapipes.iter.combining import _ChildDataPipe, IterDataPipe
from torch.utils.data.graph import traverse

if TYPE_CHECKING:
import graphviz


class Node:
def __init__(self, dp, *, name=None):
self.dp = dp
self.name = name or type(dp).__name__.replace("IterDataPipe", "")
self.childs = set()
self.parents = set()

def add_child(self, child):
self.childs.add(child)
child.parents.add(self)

def remove_child(self, child):
self.childs.remove(child)
child.parents.remove(self)

def add_parent(self, parent):
self.parents.add(parent)
parent.childs.add(self)

def remove_parent(self, parent):
self.parents.remove(parent)
parent.childs.remove(self)

def __eq__(self, other):
if not isinstance(other, Node):
return NotImplemented

return hash(self) == hash(other)

def __hash__(self):
return hash(self.dp)

def __str__(self):
return self.name

def __repr__(self):
return f"{self}-{hash(self)}"


def to_nodes(dp, *, debug: bool) -> Set[Node]:
def recurse(dp_graph, child=None):
for dp_node, dp_parents in dp_graph.items():
node = Node(dp_node)
if child is not None:
node.add_child(child)
yield node
yield from recurse(dp_parents, child=node)

def aggregate(nodes):
groups = defaultdict(list)
for node in nodes:
groups[node].append(node)

nodes = set()
for node, group in groups.items():
if len(group) == 1:
nodes.add(node)
continue

aggregated_node = Node(node.dp)

for duplicate_node in group:
for child in duplicate_node.childs.copy():
duplicate_node.remove_child(child)
aggregated_node.add_child(child)

for parent in duplicate_node.parents.copy():
duplicate_node.remove_parent(parent)
aggregated_node.add_parent(parent)

nodes.add(aggregated_node)

if debug:
return nodes

child_dp_nodes = set(
itertools.chain.from_iterable(node.parents for node in nodes if isinstance(node.dp, _ChildDataPipe))
)

if not child_dp_nodes:
return nodes

for node in child_dp_nodes:
fixed_parent_node = Node(
type(str(node).lstrip("_"), (IterDataPipe,), dict(dp=node.dp, childs=node.childs))()
)
nodes.remove(node)
nodes.add(fixed_parent_node)

for parent in node.parents.copy():
node.remove_parent(parent)
fixed_parent_node.add_parent(parent)

for child in node.childs:
nodes.remove(child)
for actual_child in child.childs.copy():
actual_child.remove_parent(child)
actual_child.add_parent(fixed_parent_node)

return nodes

return aggregate(recurse(traverse(dp, only_datapipe=True)))


def to_graph(dp, *, debug: bool = False) -> "graphviz.Digraph":
"""Turns a datapipe into a :class:`graphviz.Digraph` representing the graph of the datapipe.
.. note::
The package :mod:`graphviz` is required to use this function.
.. note::
The most common interfaces for the returned graph object are:
- :meth:`~graphviz.Digraph.render`: Save the graph to a file.
- :meth:`~graphviz.Digraph.view`: Open the graph in a viewer.
Args:
dp: Datapipe.
debug (bool): If ``True``, renders internal datapipes that are usually hidden from the user. Defaults to
``False``.
"""
try:
import graphviz
except ModuleNotFoundError:
raise ModuleNotFoundError(
"The package `graphviz` is required to be installed to use this function. "
"Please `pip install graphviz` or `conda install -c conda-forge graphviz`."
) from None

# The graph style as well as the color scheme below was copied from https://github.com/szagoruyko/pytorchviz/
# https://github.com/szagoruyko/pytorchviz/blob/0adcd83af8aa7ab36d6afd139cabbd9df598edb7/torchviz/dot.py#L78-L85
node_attr = dict(
style="filled",
shape="box",
align="left",
fontsize="10",
ranksep="0.1",
height="0.2",
fontname="monospace",
)
graph = graphviz.Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))

for node in to_nodes(dp, debug=debug):
fillcolor: Optional[str]
if not node.parents:
fillcolor = "lightblue"
elif not node.childs:
fillcolor = "darkolivegreen1"
else:
fillcolor = None

graph.node(name=repr(node), label=str(node), fillcolor=fillcolor)

for child in node.childs:
graph.edge(repr(node), repr(child))

return graph

0 comments on commit ce085e6

Please sign in to comment.