Improve graph visualization documentation #504

Closed
wants to merge 2 commits into from
8 changes: 8 additions & 0 deletions docs/source/_templates/function.rst
@@ -0,0 +1,8 @@
.. role:: hidden
:class: hidden-section
.. currentmodule:: {{ module }}


{{ name | underline }}

.. autofunction:: {{ name }}
16 changes: 15 additions & 1 deletion docs/source/torchdata.datapipes.utils.rst
@@ -1,6 +1,21 @@
Utility Functions
===========================

..
Comment: the next section will become "DataPipe Graph Visualization and Linter" when linter is added.

DataPipe Graph Visualization
-------------------------------------
.. currentmodule:: torchdata.datapipes.utils

.. autosummary::
:nosignatures:
:toctree: generated/
:template: function.rst

to_graph


File Object and Stream Utility
-------------------------------------

@@ -12,7 +27,6 @@ File Object and Stream Utility
:template: datapipe.rst

StreamWrapper
to_graph


DataLoader
26 changes: 22 additions & 4 deletions torchdata/datapipes/utils/_visualization.py
@@ -1,3 +1,9 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import itertools
from collections import defaultdict

@@ -115,7 +121,8 @@ def aggregate(nodes):


def to_graph(dp, *, debug: bool = False) -> "graphviz.Digraph":
"""Turns a datapipe into a :class:`graphviz.Digraph` representing the graph of the datapipe.
"""Visualizes a DataPipe by returning a :class:`graphviz.Digraph`, which is a graph of the data pipeline.
This allows you to visually inspect all the transformations that take place in your DataPipes.

.. note::

@@ -129,9 +136,20 @@ def to_graph(dp, *, debug: bool = False) -> "graphviz.Digraph":
- :meth:`~graphviz.Digraph.view`: Open the graph in a viewer.

Args:
dp: Datapipe.
debug (bool): If ``True``, renders internal datapipes that are usually hidden from the user. Defaults to
``False``.
dp: DataPipe that you would like to visualize (generally the last one in a chain of DataPipes).
debug (bool): If ``True``, renders internal DataPipes that are usually hidden from the user
(such as ``ChildDataPipe`` of ``demux`` and ``fork``). Defaults to ``False``.

Example:
>>> from torchdata.datapipes.iter import IterableWrapper
>>> from torchdata.datapipes.utils import to_graph
>>> dp = IterableWrapper(range(10))
>>> dp1, dp2 = dp.demux(num_instances=2, classifier_fn=lambda x: x % 2)
>>> dp1 = dp1.map(lambda x: x + 1)
>>> dp2 = dp2.filter(lambda _: True)
>>> dp3 = dp1.zip(dp2).map(lambda t: t[0] + t[1])
>>> g = to_graph(dp3)
>>> g.view() # This will open the graph in a viewer
"""
try:
import graphviz
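
For readers who want a feel for the kind of graph the docstring example describes, here is a minimal pure-Python sketch. It is not the implementation in `_visualization.py`: the `Node` class and the `collect` and `to_dot` helpers are hypothetical names introduced only for illustration, and the node labels are plain strings chosen to mirror the pipeline in the docstring example. The sketch walks upstream from the last node in a chain and emits DOT text, which is the format a `graphviz.Digraph` ultimately holds.

```python
class Node:
    """Hypothetical stand-in for a DataPipe: a label plus its upstream sources."""

    def __init__(self, name, *sources):
        self.name = name
        self.sources = sources


def collect(last):
    """Return every node reachable upstream from ``last``, each exactly once."""
    seen, order, stack = set(), [], [last]
    while stack:
        node = stack.pop()
        if id(node) in seen:
            continue
        seen.add(id(node))
        order.append(node)
        stack.extend(node.sources)
    return order


def to_dot(last):
    """Build DOT source with one labelled node per stage and one edge per dependency."""
    nodes = collect(last)
    ids = {id(n): f"n{i}" for i, n in enumerate(nodes)}
    lines = ["digraph {"]
    for n in nodes:
        lines.append(f'    {ids[id(n)]} [label="{n.name}"]')
    for n in nodes:
        for src in n.sources:
            # Edges point from the upstream stage to the stage that consumes it.
            lines.append(f"    {ids[id(src)]} -> {ids[id(n)]}")
    lines.append("}")
    return "\n".join(lines)


# Mirror the shape of the docstring example: wrap -> demux -> (map, filter) -> zip -> map.
source = Node("IterableWrapper")
demux = Node("Demultiplexer", source)
dp1 = Node("Mapper", demux)
dp2 = Node("Filter", demux)
zipped = Node("Zipper", dp1, dp2)
dp3 = Node("Mapper", zipped)

dot = to_dot(dp3)
print(dot)
```

Passing only the last node is enough because every stage keeps references to its sources, which is why the `dp` argument is described as generally the last DataPipe in a chain.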