diff --git a/docs/reference/api/python/contrib.rst b/docs/reference/api/python/contrib.rst index 52d3faff0fc4..26b5abb97ffa 100644 --- a/docs/reference/api/python/contrib.rst +++ b/docs/reference/api/python/contrib.rst @@ -97,10 +97,12 @@ tvm.contrib.relay_viz ~~~~~~~~~~~~~~~~~~~~~ .. automodule:: tvm.contrib.relay_viz :members: -.. automodule:: tvm.contrib.relay_viz.interface +.. automodule:: tvm.contrib.relay_viz.dot :members: .. automodule:: tvm.contrib.relay_viz.terminal :members: +.. automodule:: tvm.contrib.relay_viz.interface + :members: tvm.contrib.rocblas diff --git a/gallery/how_to/work_with_relay/using_relay_viz.py b/gallery/how_to/work_with_relay/using_relay_viz.py index f61fc41f4f14..10e6dab12e24 100644 --- a/gallery/how_to/work_with_relay/using_relay_viz.py +++ b/gallery/how_to/work_with_relay/using_relay_viz.py @@ -32,6 +32,8 @@ Here we use a renderer rendering graph in the text-form. It is a lightweight, AST-like visualizer, inspired by `clang ast-dump `_. We will introduce how to implement customized parsers and renderers through interface classes. + +For more details, please refer to :py:mod:`tvm.contrib.relay_viz`. """ from typing import ( Dict, diff --git a/python/tvm/contrib/relay_viz/__init__.py b/python/tvm/contrib/relay_viz/__init__.py index 32814b577d0d..fb4dac226d57 100644 --- a/python/tvm/contrib/relay_viz/__init__.py +++ b/python/tvm/contrib/relay_viz/__init__.py @@ -27,6 +27,10 @@ TermPlotter, TermVizParser, ) +from .dot import ( + DotPlotter, + DotVizParser, +) class RelayVisualizer: @@ -69,12 +73,16 @@ def __init__( node_to_id = {} # callback to generate an unique string-ID for nodes. + # node_count_offset ensure each node ID is still unique across subgraph. + node_count_offset = 0 + def traverse_expr(node): if node in node_to_id: return - node_to_id[node] = str(len(node_to_id)) + node_to_id[node] = str(len(node_to_id) + node_count_offset) for name in graph_names: + node_count_offset += len(node_to_id) node_to_id.clear() relay.analysis.post_order_visit(relay_mod[name], traverse_expr) graph = self._plotter.create_graph(name) diff --git a/python/tvm/contrib/relay_viz/dot.py b/python/tvm/contrib/relay_viz/dot.py new file mode 100644 index 000000000000..a9e98189a85a --- /dev/null +++ b/python/tvm/contrib/relay_viz/dot.py @@ -0,0 +1,221 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Visualize Relay IR by Graphviz DOT language.""" + +from typing import ( + Any, + Callable, + Dict, +) +from .interface import ( + DefaultVizParser, + Plotter, + VizEdge, + VizGraph, + VizNode, +) + +try: + import graphviz +except ImportError: + # add "from None" to silence + # "During handling of the above exception, another exception occurred" + raise ImportError( + "The graphviz package is required for DOT renderer. " + "Please install it first. For example, pip3 install graphviz" + ) from None + +DotVizParser = DefaultVizParser + + +class DotGraph(VizGraph): + """DOT graph for relay IR. + + See also :py:class:`tvm.contrib.relay_viz.dot.DotPlotter` + + Parameters + ---------- + name: str + name of this graph. + graph_attr: Optional[Dict[str, str]] + key-value pairs for the graph. + node_attr: Optional[Dict[str, str]] + key-value pairs for all nodes. + edge_attr: Optional[Dict[str, str]] + key-value pairs for all edges. + get_node_attr: Optional[Callable[[VizNode], Dict[str, str]]] + A callable returning attributes for the node. + """ + + def __init__( + self, + name: str, + graph_attr: Dict[str, str] = None, + node_attr: Dict[str, str] = None, + edge_attr: Dict[str, str] = None, + get_node_attr: Callable[[VizNode], Dict[str, str]] = None, + ): + self._name = name + self._get_node_attr = self._default_get_node_attr + if get_node_attr is not None: + self._get_node_attr = get_node_attr + + # graphviz recognizes the subgraph as a cluster subgraph + # by the name starting with "cluster" (all lowercase) + self._digraph = graphviz.Digraph( + name=f"cluster_{self._name}", + graph_attr=graph_attr, + node_attr=node_attr, + edge_attr=edge_attr, + ) + self._digraph.attr(label=self._name) + + def node(self, viz_node: VizNode) -> None: + """Add a node to the underlying graph. + Nodes in a Relay IR Module are expected to be added in the post-order. + + Parameters + ---------- + viz_node : VizNode + A `VizNode` instance. + """ + self._digraph.node( + viz_node.identity, + f"{viz_node.type_name}\n{viz_node.detail}", + **self._get_node_attr(viz_node), + ) + + def edge(self, viz_edge: VizEdge) -> None: + """Add an edge to the underlying graph. + + Parameters + ---------- + viz_edge : VizEdge + A `VizEdge` instance. + """ + self._digraph.edge(viz_edge.start, viz_edge.end) + + @property + def digraph(self): + return self._digraph + + @staticmethod + def _default_get_node_attr(node: VizNode): + if "Var" in node.type_name: + return {"shape": "ellipse"} + return {"shape": "box"} + + +class DotPlotter(Plotter): + """DOT language graph plotter + + The plotter accepts various graphviz attributes for graphs, nodes, and edges. + Please refer to https://graphviz.org/doc/info/attrs.html for available attributes. + + Parameters + ---------- + graph_attr: Optional[Dict[str, str]] + key-value pairs for all graphs. + node_attr: Optional[Dict[str, str]] + key-value pairs for all nodes. + edge_attr: Optional[Dict[str, str]] + key-value pairs for all edges. + get_node_attr: Optional[Callable[[VizNode], Dict[str, str]]] + A callable returning attributes for a specific node. + render_kwargs: Optional[Dict[str, Any]] + keyword arguments directly passed to `graphviz.Digraph.render()`. + + Examples + -------- + + .. code-block:: python + + from tvm.contrib import relay_viz + from tvm.relay.testing import resnet + + mod, param = resnet.get_workload(num_layers=18) + # graphviz attributes + graph_attr = {"color": "red"} + node_attr = {"color": "blue"} + edge_attr = {"color": "black"} + + # VizNode is passed to the callback. + # We want to color NCHW conv2d nodes. Also give Var a different shape. + def get_node_attr(node): + if "nn.conv2d" in node.type_name and "NCHW" in node.detail: + return { + "fillcolor": "green", + "style": "filled", + "shape": "box", + } + if "Var" in node.type_name: + return {"shape": "ellipse"} + return {"shape": "box"} + + # Create plotter and pass it to viz. Then render the graph. + dot_plotter = relay_viz.DotPlotter( + graph_attr=graph_attr, + node_attr=node_attr, + edge_attr=edge_attr, + get_node_attr=get_node_attr) + + viz = relay_viz.RelayVisualizer( + mod, + relay_param=param, + plotter=dot_plotter, + parser=relay_viz.DotVizParser()) + viz.render("hello") + """ + + def __init__( + self, + graph_attr: Dict[str, str] = None, + node_attr: Dict[str, str] = None, + edge_attr: Dict[str, str] = None, + get_node_attr: Callable[[VizNode], Dict[str, str]] = None, + render_kwargs: Dict[str, Any] = None, + ): + self._name_to_graph = {} + self._graph_attr = graph_attr + self._node_attr = node_attr + self._edge_attr = edge_attr + self._get_node_attr = get_node_attr + + self._render_kwargs = {} if render_kwargs is None else render_kwargs + + def create_graph(self, name): + self._name_to_graph[name] = DotGraph( + name, self._graph_attr, self._node_attr, self._edge_attr, self._get_node_attr + ) + return self._name_to_graph[name] + + def render(self, filename: str = None): + """render the graph generated from the Relay IR module. + + This function is a thin wrapper of `graphviz.Digraph.render()`. + """ + # Create or update the filename + if filename is not None: + self._render_kwargs["filename"] = filename + # default cleanup + if "cleanup" not in self._render_kwargs: + self._render_kwargs["cleanup"] = True + + root_graph = graphviz.Digraph() + for graph in self._name_to_graph.values(): + root_graph.subgraph(graph.digraph) + root_graph.render(**self._render_kwargs) diff --git a/python/tvm/contrib/relay_viz/interface.py b/python/tvm/contrib/relay_viz/interface.py index 6e52f024b1c5..45bb7758c0b8 100644 --- a/python/tvm/contrib/relay_viz/interface.py +++ b/python/tvm/contrib/relay_viz/interface.py @@ -48,7 +48,7 @@ def __init__(self, node_id: str, node_type: str, node_detail: str): self._detail = node_detail @property - def identity(self) -> Union[int, str]: + def identity(self) -> str: return self._id @property @@ -59,6 +59,10 @@ def type_name(self) -> str: def detail(self) -> str: return self._detail + def __repr__(self) -> str: + detail = self._detail.replace("\n", ", ") + return f"VizNode(identity: {self._id}, type_name: {self._type}, detail: {detail}" + class VizEdge: """VizEdge connect two `VizNode`. @@ -139,7 +143,7 @@ def edge(self, viz_edge: VizEdge) -> None: Parameters ---------- - id_start : VizEdge + viz_edge : VizEdge A `VizEdge` instance. """ @@ -277,7 +281,7 @@ def _tuple_get_item( node_id = node_to_id[node] # Tuple -> TupleGetItemNode - viz_node = VizNode(node_id, f"TupleGetItem", "idx: {node.index}") + viz_node = VizNode(node_id, f"TupleGetItem", f"idx: {node.index}") viz_edges = [VizEdge(node_to_id[node.tuple_value], node_id)] return viz_node, viz_edges diff --git a/python/tvm/contrib/relay_viz/terminal.py b/python/tvm/contrib/relay_viz/terminal.py index 7b72d9da4333..f137bbf9d41c 100644 --- a/python/tvm/contrib/relay_viz/terminal.py +++ b/python/tvm/contrib/relay_viz/terminal.py @@ -31,10 +31,11 @@ VizEdge, VizGraph, VizNode, + VizParser, ) -class TermVizParser(DefaultVizParser): +class TermVizParser(VizParser): """`TermVizParser` parse nodes and edges for `TermPlotter`.""" def __init__(self): @@ -166,7 +167,7 @@ def edge(self, viz_edge: VizEdge) -> None: Parameters ---------- - id_start : VizEdge + viz_edge : VizEdge A `VizEdge` instance. """ # Take CallNode as an example, instead of "arguments point to CallNode",