Skip to content

Commit

Permalink
tutorial and doc
Browse files Browse the repository at this point in the history
  • Loading branch information
chiwwang committed Oct 25, 2021
1 parent fb2cf61 commit 9955d7c
Show file tree
Hide file tree
Showing 8 changed files with 326 additions and 128 deletions.
10 changes: 10 additions & 0 deletions docs/reference/api/python/contrib.rst
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,16 @@ tvm.contrib.random
.. automodule:: tvm.contrib.random
:members:

tvm.contrib.relay_viz
~~~~~~~~~~~~~~~~~~~~~
.. automodule:: tvm.contrib.relay_viz
:members: RelayVisualizer
.. autoattribute:: tvm.contrib.relay_viz.PlotterBackend.BOKEH
.. autoattribute:: tvm.contrib.relay_viz.PlotterBackend.TERMINAL
.. automodule:: tvm.contrib.relay_viz.plotter
:members:
.. automodule:: tvm.contrib.relay_viz.node_edge_gen
:members:

tvm.contrib.rocblas
~~~~~~~~~~~~~~~~~~~
Expand Down
162 changes: 162 additions & 0 deletions gallery/how_to/work_with_relay/using_relay_viz.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=line-too-long
"""
Use Relay Visualizer to Visualize Relay
============================================================
**Author**: `Chi-Wei Wang <https://github.com/chiwwang>`_
This is an introduction about using Relay Visualizer to visualize a Relay IR module.
Relay IR module can contain lots of operations. Although individual
operations are usually easy to understand, they become complicated quickly
when you put them together. It could get even worse while optimiztion passes
come into play.
This utility abstracts an IR module as graphs containing nodes and edges.
It provides a default parser to interpret an IR modules with nodes and edges.
Two renderer backends are also implemented to visualize them.
Here we use a backend showing Relay IR module in the terminal for illustation.
It is a much more lightweight compared to another backend using `Bokeh <https://docs.bokeh.org/en/latest/>`_.
See ``<TVM_HOME>/python/tvm/contrib/relay_viz/README.md``.
Also we will introduce how to implement customized parsers and renderers through
some interfaces classes.
"""
from typing import (
Dict,
Union,
Tuple,
List,
)
import tvm
from tvm import relay
from tvm.contrib import relay_viz
from tvm.contrib.relay_viz.node_edge_gen import (
VizNode,
VizEdge,
NodeEdgeGenerator,
)
from tvm.contrib.relay_viz.terminal import (
TermNodeEdgeGenerator,
TermGraph,
TermPlotter,
)

######################################################################
# Define a Relay IR Module with multiple GlobalVar
# ------------------------------------------------
# Let's build an example Relay IR Module containing multiple ``GlobalVar``.
# We define an ``add`` function and call it in the main function.
data = relay.var("data")
bias = relay.var("bias")
add_op = relay.add(data, bias)
add_func = relay.Function([data, bias], add_op)
add_gvar = relay.GlobalVar("AddFunc")

input0 = relay.var("input0")
input1 = relay.var("input1")
input2 = relay.var("input2")
add_01 = relay.Call(add_gvar, [input0, input1])
add_012 = relay.Call(add_gvar, [input2, add_01])
main_func = relay.Function([input0, input1, input2], add_012)
main_gvar = relay.GlobalVar("main")

mod = tvm.IRModule({main_gvar: main_func, add_gvar: add_func})

######################################################################
# Render the graph with Relay Visualizer on the terminal
# ------------------------------------------------------
# The terminal backend can show a Relay IR module as in a text-form
# similar to `clang ast-dump <https://clang.llvm.org/docs/IntroductionToTheClangAST.html#examining-the-ast>`_.
# We should see ``main`` and ``AddFunc`` function. ``AddFunc`` is called twice in the ``main`` function.
viz = relay_viz.RelayVisualizer(mod, {}, relay_viz.PlotterBackend.TERMINAL)
viz.render()

######################################################################
# Customize Parser for Interested Relay Types
# -------------------------------------------
# Sometimes the information shown by the default implementation is not suitable
# for a specific usage. It is possible to provide your own parser and renderer.
# Here demostrate how to customize parsers for ``relay.var``.
# We need to implement :py:class:`tvm.contrib.relay_viz.node_edge_gen.NodeEdgeGenerator` interface.
class YourAwesomeParser(NodeEdgeGenerator):
def __init__(self):
self._org_parser = TermNodeEdgeGenerator()

def get_node_edges(
self,
node: relay.Expr,
relay_param: Dict[str, tvm.runtime.NDArray],
node_to_id: Dict[relay.Expr, Union[int, str]],
) -> Tuple[Union[VizNode, None], List[VizEdge]]:

if isinstance(node, relay.Var):
node = VizNode(node_to_id[node], "AwesomeVar", f"name_hint {node.name_hint}")
# no edge is introduced. So return an empty list.
ret = (node, [])
return ret

# delegate other types to the original parser.
return self._org_parser.get_node_edges(node, relay_param, node_to_id)


######################################################################
# Pass a tuple of :py:class:`tvm.contrib.relay_viz.plotter.Plotter` and
# :py:class:`tvm.contrib.relay_viz.node_edge_gen.NodeEdgeGenerator` instances
# to ``RelayVisualizer``. Here we re-use the Plotter interface implemented inside
# ``relay_viz.terminal`` module.
viz = relay_viz.RelayVisualizer(mod, {}, (TermPlotter(), YourAwesomeParser()))
viz.render()

######################################################################
# More Customization around Graph and Plotter
# -------------------------------------------
# All ``RelayVisualizer`` care about are interfaces defined in ``plotter.py`` and
# ``node_edge_generator.py``. We can override them to introduce custimized logics.
# For example, if we want the Graph to duplicate above ``AwesomeVar`` while it is added,
# we can override ``relay_viz.terminal.TermGraph.node``.
class AwesomeGraph(TermGraph):
def node(self, node_id, node_type, node_detail):
# add original node first
super().node(node_id, node_type, node_detail)
if node_type == "AwesomeVar":
duplicated_id = f"duplciated_{node_id}"
duplicated_type = "double AwesomeVar"
super().node(duplicated_id, duplicated_type, "")
# connect the duplicated var to the original one
super().edge(duplicated_id, node_id)


# override TermPlotter to return `AwesomeGraph` instead
class AwesomePlotter(TermPlotter):
def create_graph(self, name):
self._name_to_graph[name] = AwesomeGraph(name)
return self._name_to_graph[name]


viz = relay_viz.RelayVisualizer(mod, {}, (AwesomePlotter(), YourAwesomeParser()))
viz.render()

######################################################################
# Summary
# -------
# This tutorial demonstrates the usage of Relay Visualizer.
# The class :py:class:`tvm.contrib.relay_viz.RelayVisualizer` is composed of interfaces
# defined in ``plotter.py`` and ``node_edge_generator.py``. It provides a single entry point
# while keeping the possibility of implementing customized visualizer in various cases.
#
10 changes: 7 additions & 3 deletions python/tvm/contrib/relay_viz/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ This tool target to visualize Relay IR.

## Requirement

### Terminal Backend
1. TVM

### Bokeh Backend
1. TVM
2. graphviz
2. pydot
Expand Down Expand Up @@ -66,9 +70,9 @@ This utility is composed of two parts: `node_edge_gen.py` and `plotter.py`.

`plotter.py` define interfaces of `Graph` and `Plotter`. `Plotter` is responsible to render a collection of `Graph`.

`node_edge_gen.py` define interfaces of converting Relay IR modules to nodes/edges consumed by `Graph`. Further, this python module also provide a default implementation for common relay types.
`node_edge_gen.py` define interfaces of converting Relay IR modules to nodes and edges. Further, this python module provide a default implementation for common relay types.

If customization is wanted for a certain relay type, we can implement the `NodeEdgeGenerator` interface, handling that relay type accordingly, and delegate other types to the default implementation. See `_terminal.py` for an example usage.

These two interfaces are glued by the top level class `RelayVisualizer`, which passes a relay module to `NodeEdgeGenerator` and add nodes/edges to `Graph`.
Then, it render the plot by `Plotter.render()`.
These two interfaces are glued by the top level class `RelayVisualizer`, which passes a relay module to `NodeEdgeGenerator` and add nodes and edges to `Graph`.
Then, it render the plot by calling `Plotter.render()`.
36 changes: 21 additions & 15 deletions python/tvm/contrib/relay_viz/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,29 +28,33 @@


class PlotterBackend(Enum):
"""Enumeration for available plotters."""
"""Enumeration for available plotter backends."""

BOKEH = "bokeh"
TERMINAL = "terminal"


class RelayVisualizer:
"""Relay IR Visualizer"""
"""Relay IR Visualizer
Parameters
----------
relay_mod : tvm.IRModule
Relay IR module.
relay_param: None | Dict[str, tvm.runtime.NDArray]
Relay parameter dictionary. Default `None`.
backend: PlotterBackend | Tuple[Plotter, NodeEdgeGenerator]
The backend used to render graphs. It can be a tuple of an implemented Plotter instance and
NodeEdgeGenerator instance to introduce customized parsing and visualization logics.
Default ``PlotterBackend.TERMINAL``.
"""

def __init__(
self,
relay_mod: tvm.IRModule,
relay_param: Union[None, Dict[str, tvm.runtime.NDArray]] = None,
backend: Union[PlotterBackend, Tuple[Plotter, NodeEdgeGenerator]] = PlotterBackend.TERMINAL,
):
"""Visualize Relay IR.
Parameters
----------
relay_mod : tvm.IRModule, Relay IR module
relay_param: None | Dict[str, tvm.runtime.NDArray], Relay parameter dictionary. Default `None`.
backend: PlotterBackend | Tuple[Plotter, NodeEdgeGenerator], Default `PlotterBackend.TERMINAL`.
"""

self._plotter, self._ne_generator = get_plotter_and_generator(backend)
self._relay_param = relay_param if relay_param is not None else {}
Expand Down Expand Up @@ -83,8 +87,10 @@ def _add_nodes(self, graph, node_to_id, relay_param):
Parameters
----------
graph : `plotter.Graph`
graph : plotter.Graph
node_to_id : Dict[relay.expr, str | int]
relay_param : Dict[str, tvm.runtime.NDarray]
"""
for node in node_to_id:
Expand All @@ -102,11 +108,11 @@ def get_plotter_and_generator(backend):
"""Specify the Plottor and its NodeEdgeGenerator"""
if isinstance(backend, (tuple, list)) and len(backend) == 2:
if not isinstance(backend[0], Plotter):
raise ValueError(f"First element of backend should be derived from {type(Plotter)}")
raise ValueError(f"First element should be an instance derived from {type(Plotter)}")

if not isinstance(backend[1], NodeEdgeGenerator):
raise ValueError(
f"Second element of backend should be derived from {type(NodeEdgeGenerator)}"
f"Second element should be an instance derived from {type(NodeEdgeGenerator)}"
)

return backend
Expand All @@ -118,7 +124,7 @@ def get_plotter_and_generator(backend):
# Basically we want to keep them optional. Users can choose plotters they want to install.
if backend == PlotterBackend.BOKEH:
# pylint: disable=import-outside-toplevel
from ._bokeh import (
from .bokeh import (
BokehPlotter,
BokehNodeEdgeGenerator,
)
Expand All @@ -127,7 +133,7 @@ def get_plotter_and_generator(backend):
ne_generator = BokehNodeEdgeGenerator()
elif backend == PlotterBackend.TERMINAL:
# pylint: disable=import-outside-toplevel
from ._terminal import (
from .terminal import (
TermPlotter,
TermNodeEdgeGenerator,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,22 +19,9 @@
import functools
import logging

_LOGGER = logging.getLogger(__name__)

import numpy as np

try:
import pydot
except ImportError:
_LOGGER.critical("pydot library is required. You might want to run pip install pydot.")
raise

try:
from bokeh.io import output_file, save
except ImportError:
_LOGGER.critical("bokeh library is required. You might want to run pip install bokeh.")
raise

import pydot
from bokeh.io import output_file, save
from bokeh.models import (
ColumnDataSource,
CustomJS,
Expand Down Expand Up @@ -63,6 +50,8 @@

from .node_edge_gen import DefaultNodeEdgeGenerator

_LOGGER = logging.getLogger(__name__)

# Use default node/edge generator
BokehNodeEdgeGenerator = DefaultNodeEdgeGenerator

Expand Down
Loading

0 comments on commit 9955d7c

Please sign in to comment.