From 7c9d5dd11455d6ac3e7f3b8720f08acd2b5df725 Mon Sep 17 00:00:00 2001 From: chiwwang Date: Fri, 14 May 2021 22:34:55 +0800 Subject: [PATCH 01/18] Fix AttributeError when TEST_DATA_ROOT_PATH is set Initiate a Path object from TEST_DATA_ROOT_PATH to fix the error: AttributeError: 'str' object has no attribute 'mkdir' --- python/tvm/contrib/download.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/contrib/download.py b/python/tvm/contrib/download.py index dd7967ff7f0b..f7c68a99229e 100644 --- a/python/tvm/contrib/download.py +++ b/python/tvm/contrib/download.py @@ -121,7 +121,7 @@ def _download_progress(count, block_size, total_size): if "TEST_DATA_ROOT_PATH" in environ: - TEST_DATA_ROOT_PATH = environ.get("TEST_DATA_ROOT_PATH") + TEST_DATA_ROOT_PATH = Path(environ.get("TEST_DATA_ROOT_PATH")) else: TEST_DATA_ROOT_PATH = Path(Path("~").expanduser(), ".tvm_test_data") TEST_DATA_ROOT_PATH.mkdir(parents=True, exist_ok=True) From 43d40c42c96d01f6a72ecbce3c78fd2ebd75051e Mon Sep 17 00:00:00 2001 From: chiwwang Date: Wed, 9 Jun 2021 09:37:50 +0800 Subject: [PATCH 02/18] [DOCS] Add docs for Pass Instrument - Add a tutorial about how to use pass instrument. - Add related sections in Pass Infrastructure documents. --- docs/api/python/ir.rst | 8 ++ docs/dev/pass_infra.rst | 105 +++++++++++++- python/tvm/ir/instrument.py | 13 +- tutorials/dev/use_pass_instrument.py | 204 +++++++++++++++++++++++++++ 4 files changed, 321 insertions(+), 9 deletions(-) create mode 100644 tutorials/dev/use_pass_instrument.py diff --git a/docs/api/python/ir.rst b/docs/api/python/ir.rst index c2a1a1e106d5..348939e4e6fb 100644 --- a/docs/api/python/ir.rst +++ b/docs/api/python/ir.rst @@ -23,6 +23,14 @@ tvm.ir :autosummary: +tvm.ir.instrument +------ +.. automodule:: tvm.ir.instrument + :members: + :imported-members: + :autosummary: + + tvm.transform ------------- .. automodule:: tvm.transform diff --git a/docs/dev/pass_infra.rst b/docs/dev/pass_infra.rst index 67ef30a29504..c96a622fb957 100644 --- a/docs/dev/pass_infra.rst +++ b/docs/dev/pass_infra.rst @@ -109,7 +109,8 @@ configure the compilation options, including optimization level and required/disabled passes, etc. For instance, we may have a configuration which performs all passes at ``opt_level=3`` with some disabled passes using ``disabled_pass=xx`` provided by ``PassContext``. Now we could glob all passes -at ``opt_level=3`` and exclude those in the disabled pass list. +at ``opt_level=3`` and exclude those in the disabled pass list. ``PassContext`` +also provides pass-instruments mechanism, which will be introduced latter. This class is designed for users to conveniently write the Python ``with`` syntax to perform optimizations under a certain configuration. In addition, the @@ -123,16 +124,23 @@ Python APIs to create a compilation pipeline using pass context. class PassContextNode : public Object { public: - ErrorReporter err_reporter; int opt_level{2}; tvm::Array required_pass; tvm::Array disabled_pass; + mutable Optional diag_ctx; + Map config; + Array instruments; }; class PassContext : public NodeRef { public: TVM_DLL static PassContext Create(); TVM_DLL static PassContext Current(); + TVM_DLL void InstrumentEnterPassContext(); + TVM_DLL void InstrumentExitPassContext(); + TVM_DLL bool InstrumentBeforePass(const IRModule& mod, const PassInfo& info) const; + TVM_DLL void InstrumentAfterPass(const IRModule& mod, const PassInfo& info) const; + TVM_DLL bool PassEnabled(const PassInfo& info) const; /* Other fields are omitted. */ private: @@ -389,6 +397,51 @@ To allow other C++ modules to apply this pass, we declare a free function in TVM_DLL Pass FoldConstant(); +Pass Instrument +~~~~~~~~~~~~~~~ + +To instrument passes, four methods are introduced to ``PassContext``. + +.. code:: c++ + + TVM_DLL void InstrumentEnterPassContext(); + TVM_DLL void InstrumentExitPassContext(); + TVM_DLL bool InstrumentBeforePass(const IRModule& mod, const PassInfo& info) const; + TVM_DLL void InstrumentAfterPass(const IRModule& mod, const PassInfo& info) const; + +The first two methods are called respectively in entering/exiting context scope. The latter two are called while passes is being applied(`src/ir/transform.cc`_). + +Note that ``InstrumentBeforePass()`` return a boolean indicating this pass should +be run or not. + +``PassInstrument`` provides callbacks run by these methods. Multiple +``PassInstrument`` instances can be registed into a single ``PassContext``. +They are called sequentially in the order of ``instruments`` member. + +.. code:: c++ + + namespace instrument { + + class PassInstrumentNode : public Object { + public: + String name; + virtual void EnterPassContext() const = 0; + virtual void ExitPassContext() const = 0; + virtual bool ShouldRun(const IRModule& mod, const transform::PassInfo& info) const = 0; + virtual void RunBeforePass(const IRModule& mod, const transform::PassInfo& info) const = 0; + virtual void RunAfterPass(const IRModule& mod, const transform::PassInfo& info) const = 0; + /* Other fields are omitted. */ + }; + + class PassInstrument : public ObjectRef { + public: + TVM_DEFINE_OBJECT_REF_METHODS(PassInstrument, ObjectRef, PassInstrumentNode); + }; + + } // namespace instrument + +Python interfaces are provided to implement ``PassInstrument`` quickly. + Python Frontend ~~~~~~~~~~~~~~~ @@ -526,6 +579,51 @@ decorators and then invoke it. For more examples about how to customize your own optimization pipeline and debug Relay and tir passes, please refer to the `use pass infra`_ tutorial. +Pass Instrument +^^^^^^^^^^^^^^^ + +A customizable framework to instrument passes is provided. ``PassInstrument`` classes can be registered while constructing ``PassContext``. + +.. code:: python + + @tvm._ffi.register_object("transform.PassContext") + class PassContext(tvm.runtime.Object): + def __init__( + self, + opt_level=2, + required_pass=None, + disabled_pass=None, + instruments=None, + config=None, + ): + # ... + +One can implement a ``PassInstrument`` by ``pass_instrument`` decorator(`python/tvm/ir/instrument.py`_) with a class implementing following methods: + +- ``enter_pass_ctx`` + + * This callback is run at the moement of entering ``PassContext``. + +- ``exit_pass_ctx`` + + * This callback is run at the moement of exiting ``PassContext``. + +- ``should_run`` + + * This callback is run before a pass is executed, returning a boolean indicating if the pass should be run. + * If a pass is listed as required, this callback will not be executed for that pass. + +- ``run_before_pass`` + + * If a pass should be run, this callback is run just before pass execution. + +- ``run_after_pass`` + + * This callback is run right after a pass has been executed. + + +`use pass instrument`_ tutorial provides examples for how to implement ``PassInstrument`` with Python APIs. + .. _Sequential: https://pytorch.org/docs/stable/nn.html?highlight=sequential#torch.nn.Sequential .. _Block: https://mxnet.apache.org/api/python/docs/api/gluon/block.html#gluon-block @@ -544,6 +642,9 @@ optimization pipeline and debug Relay and tir passes, please refer to the .. _python/tvm/ir/transform.py: https://github.com/apache/tvm/blob/main/python/tvm/ir/transform.py +.. _python/tvm/ir/instrument.py: https://github.com/apache/tvm/blob/main/python/tvm/ir/instrument.py + .. _src/tir/transforms/unroll_loop.cc: https://github.com/apache/tvm/blob/main/src/tir/transforms/unroll_loop.cc .. _use pass infra: https://github.com/apache/tvm/blob/main/tutorials/dev/use_pass_infra.py +.. _use pass instrument: https://github.com/apache/tvm/blob/main/tutorials/dev/use_pass_instrument.py diff --git a/python/tvm/ir/instrument.py b/python/tvm/ir/instrument.py index c322f2bef3fc..ab2fcb9591cc 100644 --- a/python/tvm/ir/instrument.py +++ b/python/tvm/ir/instrument.py @@ -30,11 +30,8 @@ class PassInstrument(tvm.runtime.Object): """A pass instrument implementation. Users don't need to interact with this class directly. - Instead, a `PassInstrument` instance should be created through `pass_instrument`. - - See Also - -------- - `pass_instrument` + Instead, a `PassInstrument` instance should be created through + :py:func:`pass_instrument` """ @@ -91,13 +88,15 @@ def pass_instrument(pi_cls=None): Parameters ---------- - pi_class : + pi_class : class + Instrument class. See example below. Examples -------- - The following code block decorates a pass instrument class. + The following code block show how to decorate a pass instrument class. .. code-block:: python + @tvm.instrument.pass_instrument class SkipPass: def __init__(self, skip_pass_name): diff --git a/tutorials/dev/use_pass_instrument.py b/tutorials/dev/use_pass_instrument.py new file mode 100644 index 000000000000..59bc68edf19b --- /dev/null +++ b/tutorials/dev/use_pass_instrument.py @@ -0,0 +1,204 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=line-too-long +""" +.. _tutorial-use-pass-instrument: + +How to Use TVM Pass Instrument +============================== +**Author**: `Chi-Wei Wang `_ + +As more and more passes are implemented, it becomes interesting to instrument +passes execution, analyze per-pass effects and observe various events. +We have extended :py:class:`tvm.transform.PassContext` to accept a list of +instrument classes. Also a decorator :py:func:`tvm.ir.instrument.pass_instrument` is provided to easily implement instrument classes. + +This tutorial demostrates how developers can use ``PassContext`` to instrument +passes. For more details, please refer to the :ref:`pass-infra` +""" +import tvm +from tvm import relay +from tvm.contrib.download import download_testdata +from tvm.relay.build_module import bind_params_by_name +from tvm.ir.instrument import ( + PassTimingInstrument, + pass_instrument, +) + +############################################################################### +# Create An Example Relay Program +# ------------------------------- +# We create a Relay program from a Pytorch model. +# Here we pick up ``mobilenet_v2`` from torchvision. +import torch +import torchvision + +model_name = "mobilenet_v2" +model = getattr(torchvision.models, model_name)(pretrained=True) +model = model.eval() + +input_shape = [1, 3, 224, 224] +input_data = torch.randn(input_shape) +scripted_model = torch.jit.trace(model, input_data).eval() + +shape_list = [("input0", input_shape)] +relay_mod, relay_params = relay.frontend.from_pytorch(scripted_model, shape_list) + + +############################################################################### +# Create PassContext With Instruments +# ----------------------------------- +# It is as simple as passing ``instruments`` argument to ``PassContext`` constructor. +# A built-in ``PassTimingInstrument`` is used to profile the execution time of +# each passes. +timing_inst = PassTimingInstrument() +with tvm.transform.PassContext(instruments=[timing_inst]): + relay_mod = relay.transform.InferType()(relay_mod) + relay_mod = relay.transform.FoldScaleAxis()(relay_mod) + # before exiting the context, get profile results. + profiles = timing_inst.render() +print(profiles) + + +############################################################################### +# Create Customized Instrument Class +# ---------------------------------- +# A customized instrument class can be easily created by +# :py:func:`tvm.ir.instrument.pass_instrument` decorator. +# +# Let's create an instrument class which calculate the difference of ``CallNode`` +# counting per ``op.name`` before and after passes. + +# decorate the class +@pass_instrument +class RelayCallNodeDiffer: + + def __init__(self): + self._op_diff = [] + # Passes can be nested. + # Use stack to make sure we get correct before/after pairs. + self._op_cnt_before_stack = [] + + def enter_pass_ctx(self): + self._op_diff = [] + self._op_cnt_before_stack = [] + + def exit_pass_ctx(self): + assert len(self._op_cnt_before_stack) == 0, \ + "The stack is not empty. Something wrong." + + def run_before_pass(self, mod, info): + cur_depth = len(self._op_cnt_before_stack) + self._op_cnt_before_stack.append((info.name, self._count_nodes(mod))) + + def run_after_pass(self, mod, info): + # Pop out the latest recorded pass. + name_before, op_to_cnt_before = self._op_cnt_before_stack.pop() + assert name_before == info.name, \ + "name_before: {}, info.name: {} doesn't match".format(name_before, info.name) + cur_depth = len(self._op_cnt_before_stack) + op_to_cnt_after = self._count_nodes(mod) + op_diff = self._diff(op_to_cnt_after, op_to_cnt_before) + # only record passes causing differences. + if op_diff: + self._op_diff.append((cur_depth, info.name, op_diff)) + + def get_pass_to_op_diff(self): + """ + return [ + (depth, pass_name, {op_name: diff_num, ...}), ... + ] + """ + return self._op_diff + + @staticmethod + def _count_nodes(mod): + ret = {} + def visit(node): + if isinstance(node, relay.expr.Call): + try: + op_name = node.op.name + except AttributeError: + # Some CallNode may not have 'name' such as relay.Function + return + try: + ret[op_name] += 1 + except KeyError: + ret[op_name] = 1 + relay.analysis.post_order_visit(mod["main"], visit) + return ret + + @staticmethod + def _diff(d_after, d_before): + # d_after - d_before + ret = {} + key_after, key_before = set(d_after), set(d_before) + for k in key_before & key_after: + tmp = d_after[k] - d_before[k] + if tmp: + ret[k] = d_after[k] - d_before[k] + for k in key_after - key_before: + ret[k] = d_after[k] + for k in key_before - key_after: + ret[k] = -d_before[k] + return ret + + +############################################################################### +# Apply Passes and Multiple Instrument Classes +# -------------------------------------------- +# Apply any pass you wish. Here :py:class:`tvm.relay.transform.ConvertLayout` +# and :py:class:`tvm.relay.transform.FoldConstant` are used. +# +# ``ConvertLayout`` might add ``layout_transform`` Op while ``FoldConstant`` can +# reduce the number of ``CallNode``. +# +# We can also use multiple instrument classes in a ``PassContext``. +# However, it should be noted that instrument methods are executed sequentially, +# obeying the order of ``instruments`` argument. +# So for instrument classes like ``PassTimingInstrument``, it is inevitable to +# count-up the execution time of other instrument classes to the final +# profile result. +call_node_inst = RelayCallNodeDiffer() +desired_layouts = { + "nn.conv2d": ["NHWC", "HWIO"], +} +# Because layout_transform may be added as a successor of Constant, +# we run FoldConstant twice. +# Though it is obvious only the FoldConstant after the ConvertLayout matter, +# we want to show how many layout_transform is added as a successor of +# Constant. +pass_seq = tvm.transform.Sequential([ + relay.transform.FoldConstant(), + relay.transform.ConvertLayout(desired_layouts), + relay.transform.FoldConstant(), +]) +# bind parameters to make VarNode as ConstantNode. +relay_mod["main"] = bind_params_by_name(relay_mod["main"], relay_params) +# timing_inst is put after call_node_inst. +# So the execution time of ``call_node.inst.run_after_pass()`` is also counted. +with tvm.transform.PassContext(opt_level=3, + instruments=[call_node_inst, timing_inst]): + relay_mod = pass_seq(relay_mod) + profiles = timing_inst.render() +# Uncomment the next line to see timing-profile results. +#print(profiles) +# +# We can see how many CallNode increase/decrease per op type. +# +from pprint import pprint +pprint(call_node_inst.get_pass_to_op_diff()) From 42df16ec412572d19c040fc3d1c1a02b828d5941 Mon Sep 17 00:00:00 2001 From: chiwwang Date: Wed, 9 Jun 2021 16:16:58 +0800 Subject: [PATCH 03/18] Fix ir.rst, the length of separator. --- docs/api/python/ir.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/api/python/ir.rst b/docs/api/python/ir.rst index 348939e4e6fb..46b4e4e03755 100644 --- a/docs/api/python/ir.rst +++ b/docs/api/python/ir.rst @@ -24,7 +24,7 @@ tvm.ir tvm.ir.instrument ------- +----------------- .. automodule:: tvm.ir.instrument :members: :imported-members: From 0485c45ac6439524d2e7c6af82582bfd66b37e8b Mon Sep 17 00:00:00 2001 From: chiwwang Date: Wed, 9 Jun 2021 16:20:21 +0800 Subject: [PATCH 04/18] Fix unused local name --- tutorials/dev/use_pass_instrument.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tutorials/dev/use_pass_instrument.py b/tutorials/dev/use_pass_instrument.py index 59bc68edf19b..7eb9965e3e3b 100644 --- a/tutorials/dev/use_pass_instrument.py +++ b/tutorials/dev/use_pass_instrument.py @@ -102,7 +102,6 @@ def exit_pass_ctx(self): "The stack is not empty. Something wrong." def run_before_pass(self, mod, info): - cur_depth = len(self._op_cnt_before_stack) self._op_cnt_before_stack.append((info.name, self._count_nodes(mod))) def run_after_pass(self, mod, info): From 7d81b5874a46a8c68ddaf41b84bb90334e9b0748 Mon Sep 17 00:00:00 2001 From: chiwwang Date: Wed, 9 Jun 2021 16:28:51 +0800 Subject: [PATCH 05/18] Fix linting errors --- tutorials/dev/use_pass_instrument.py | 27 +++++++++++++++------------ 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/tutorials/dev/use_pass_instrument.py b/tutorials/dev/use_pass_instrument.py index 7eb9965e3e3b..535ed3099cbc 100644 --- a/tutorials/dev/use_pass_instrument.py +++ b/tutorials/dev/use_pass_instrument.py @@ -86,7 +86,6 @@ # decorate the class @pass_instrument class RelayCallNodeDiffer: - def __init__(self): self._op_diff = [] # Passes can be nested. @@ -94,12 +93,11 @@ def __init__(self): self._op_cnt_before_stack = [] def enter_pass_ctx(self): - self._op_diff = [] - self._op_cnt_before_stack = [] + self._op_diff = [] + self._op_cnt_before_stack = [] def exit_pass_ctx(self): - assert len(self._op_cnt_before_stack) == 0, \ - "The stack is not empty. Something wrong." + assert len(self._op_cnt_before_stack) == 0, "The stack is not empty. Something wrong." def run_before_pass(self, mod, info): self._op_cnt_before_stack.append((info.name, self._count_nodes(mod))) @@ -107,8 +105,9 @@ def run_before_pass(self, mod, info): def run_after_pass(self, mod, info): # Pop out the latest recorded pass. name_before, op_to_cnt_before = self._op_cnt_before_stack.pop() - assert name_before == info.name, \ - "name_before: {}, info.name: {} doesn't match".format(name_before, info.name) + assert name_before == info.name, "name_before: {}, info.name: {} doesn't match".format( + name_before, info.name + ) cur_depth = len(self._op_cnt_before_stack) op_to_cnt_after = self._count_nodes(mod) op_diff = self._diff(op_to_cnt_after, op_to_cnt_before) @@ -127,6 +126,7 @@ def get_pass_to_op_diff(self): @staticmethod def _count_nodes(mod): ret = {} + def visit(node): if isinstance(node, relay.expr.Call): try: @@ -138,6 +138,7 @@ def visit(node): ret[op_name] += 1 except KeyError: ret[op_name] = 1 + relay.analysis.post_order_visit(mod["main"], visit) return ret @@ -181,23 +182,25 @@ def _diff(d_after, d_before): # Though it is obvious only the FoldConstant after the ConvertLayout matter, # we want to show how many layout_transform is added as a successor of # Constant. -pass_seq = tvm.transform.Sequential([ +pass_seq = tvm.transform.Sequential( + [ relay.transform.FoldConstant(), relay.transform.ConvertLayout(desired_layouts), relay.transform.FoldConstant(), -]) + ] +) # bind parameters to make VarNode as ConstantNode. relay_mod["main"] = bind_params_by_name(relay_mod["main"], relay_params) # timing_inst is put after call_node_inst. # So the execution time of ``call_node.inst.run_after_pass()`` is also counted. -with tvm.transform.PassContext(opt_level=3, - instruments=[call_node_inst, timing_inst]): +with tvm.transform.PassContext(opt_level=3, instruments=[call_node_inst, timing_inst]): relay_mod = pass_seq(relay_mod) profiles = timing_inst.render() # Uncomment the next line to see timing-profile results. -#print(profiles) +# print(profiles) # # We can see how many CallNode increase/decrease per op type. # from pprint import pprint pprint(call_node_inst.get_pass_to_op_diff()) + From 4f8b8e495e08156ebffc3b1b17dd89d40e3ac5d2 Mon Sep 17 00:00:00 2001 From: chiwwang Date: Wed, 9 Jun 2021 16:38:29 +0800 Subject: [PATCH 06/18] Fix linting errors --- tutorials/dev/use_pass_instrument.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tutorials/dev/use_pass_instrument.py b/tutorials/dev/use_pass_instrument.py index 535ed3099cbc..5f573881a298 100644 --- a/tutorials/dev/use_pass_instrument.py +++ b/tutorials/dev/use_pass_instrument.py @@ -184,9 +184,9 @@ def _diff(d_after, d_before): # Constant. pass_seq = tvm.transform.Sequential( [ - relay.transform.FoldConstant(), - relay.transform.ConvertLayout(desired_layouts), - relay.transform.FoldConstant(), + relay.transform.FoldConstant(), + relay.transform.ConvertLayout(desired_layouts), + relay.transform.FoldConstant(), ] ) # bind parameters to make VarNode as ConstantNode. @@ -202,5 +202,6 @@ def _diff(d_after, d_before): # We can see how many CallNode increase/decrease per op type. # from pprint import pprint + pprint(call_node_inst.get_pass_to_op_diff()) From 3c666b52c5a1f0da4916bb30b433c6f040388e8b Mon Sep 17 00:00:00 2001 From: chiwwang Date: Wed, 9 Jun 2021 16:41:13 +0800 Subject: [PATCH 07/18] Fix linting errors --- tutorials/dev/use_pass_instrument.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tutorials/dev/use_pass_instrument.py b/tutorials/dev/use_pass_instrument.py index 5f573881a298..4b749863e9a4 100644 --- a/tutorials/dev/use_pass_instrument.py +++ b/tutorials/dev/use_pass_instrument.py @@ -204,4 +204,3 @@ def _diff(d_after, d_before): from pprint import pprint pprint(call_node_inst.get_pass_to_op_diff()) - From 18317ced61060e9fba6086a53a0ce736db4bbc38 Mon Sep 17 00:00:00 2001 From: chiwwang Date: Fri, 11 Jun 2021 21:41:42 +0800 Subject: [PATCH 08/18] Address code-review feedbacks --- docs/api/python/ir.rst | 6 +- docs/conf.py | 7 +- docs/dev/pass_infra.rst | 103 ++++++++++++++++----- python/tvm/ir/instrument.py | 19 +++- python/tvm/ir/transform.py | 4 +- tests/python/relay/test_pass_instrument.py | 3 - tutorials/dev/use_pass_infra.py | 15 +-- tutorials/dev/use_pass_instrument.py | 70 +++++++++----- 8 files changed, 162 insertions(+), 65 deletions(-) diff --git a/docs/api/python/ir.rst b/docs/api/python/ir.rst index 46b4e4e03755..e7fb3c114689 100644 --- a/docs/api/python/ir.rst +++ b/docs/api/python/ir.rst @@ -23,9 +23,9 @@ tvm.ir :autosummary: -tvm.ir.instrument ------------------ -.. automodule:: tvm.ir.instrument +tvm.instrument +-------------- +.. automodule:: tvm.instrument :members: :imported-members: :autosummary: diff --git a/docs/conf.py b/docs/conf.py index 45f5da670608..f2d3ef9e88fa 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -273,7 +273,12 @@ def git_describe_version(original_version): "tune_network_x86.py", "tune_network_cuda.py", ], - "dev": ["low_level_custom_pass.py", "use_pass_infra.py", "bring_your_own_datatypes.py"], + "dev": [ + "low_level_custom_pass.py", + "use_pass_infra.py", + "bring_your_own_datatypes.py", + "use_pass_instrument.py" + ], } diff --git a/docs/dev/pass_infra.rst b/docs/dev/pass_infra.rst index c96a622fb957..b2a5923f3e88 100644 --- a/docs/dev/pass_infra.rst +++ b/docs/dev/pass_infra.rst @@ -110,7 +110,7 @@ required/disabled passes, etc. For instance, we may have a configuration which performs all passes at ``opt_level=3`` with some disabled passes using ``disabled_pass=xx`` provided by ``PassContext``. Now we could glob all passes at ``opt_level=3`` and exclude those in the disabled pass list. ``PassContext`` -also provides pass-instruments mechanism, which will be introduced latter. +also provides a way to instrument all passes. See section :ref:`pass_instrument_section_tag`. This class is designed for users to conveniently write the Python ``with`` syntax to perform optimizations under a certain configuration. In addition, the @@ -140,7 +140,6 @@ Python APIs to create a compilation pipeline using pass context. TVM_DLL void InstrumentExitPassContext(); TVM_DLL bool InstrumentBeforePass(const IRModule& mod, const PassInfo& info) const; TVM_DLL void InstrumentAfterPass(const IRModule& mod, const PassInfo& info) const; - TVM_DLL bool PassEnabled(const PassInfo& info) const; /* Other fields are omitted. */ private: @@ -397,26 +396,14 @@ To allow other C++ modules to apply this pass, we declare a free function in TVM_DLL Pass FoldConstant(); +.. _pass_instrument_section_tag: + Pass Instrument ~~~~~~~~~~~~~~~ -To instrument passes, four methods are introduced to ``PassContext``. - -.. code:: c++ - - TVM_DLL void InstrumentEnterPassContext(); - TVM_DLL void InstrumentExitPassContext(); - TVM_DLL bool InstrumentBeforePass(const IRModule& mod, const PassInfo& info) const; - TVM_DLL void InstrumentAfterPass(const IRModule& mod, const PassInfo& info) const; - -The first two methods are called respectively in entering/exiting context scope. The latter two are called while passes is being applied(`src/ir/transform.cc`_). - -Note that ``InstrumentBeforePass()`` return a boolean indicating this pass should -be run or not. - -``PassInstrument`` provides callbacks run by these methods. Multiple -``PassInstrument`` instances can be registed into a single ``PassContext``. -They are called sequentially in the order of ``instruments`` member. +``PassInstrument`` provides callbacks run when entering/exiting ``PassContext`` and before/after executing passes. +Multiple ``PassInstrument`` instances can be registed into a single ``PassContext``. +Instrument instances are called sequentially in the order of ``instruments`` argument passed to ``PassContext``. .. code:: c++ @@ -442,6 +429,70 @@ They are called sequentially in the order of ``instruments`` member. Python interfaces are provided to implement ``PassInstrument`` quickly. +Following four methods are invoked in the life-cycle of ``PassContext``. + +.. code:: c++ + + TVM_DLL void InstrumentEnterPassContext(); + TVM_DLL void InstrumentExitPassContext(); + TVM_DLL bool InstrumentBeforePass(const IRModule& mod, const PassInfo& info) const; + TVM_DLL void InstrumentAfterPass(const IRModule& mod, const PassInfo& info) const; + +``InstrumentEnterPassContext`` is called immediately when the scope +of the ``PassContext`` instance is entered. + +``InstrumentExitPassContext`` is called when the scope of ``PassContextNode`` +is being leaved, or exceptions occur during the execution of passes. +This method is also called when instruments is being overriden by ``override_instruments`` in ::py:class:`tvm.transform.PassContext`. + +``InstrumentBeforePass`` is called before pass-execution. +``InstrumentAfterPass`` is called after pass-executioon if the pass should be run. The behavir is like: + +.. code:: c++ + + if (pass_ctx.InstrumentBeforePass(ir_module, pass_info)) { + new_ir_module = run_pass(ir_module, pass_ctx); + pass_ctx.InstrumentAfterPass(new_ir_module, pass_info); + return new_ir_module; + } + +Here is a brief introduction of each methods. See (`src/ir/transform.cc`_) for more details. + +- ``InstrumentEnterPassContext`` + + * ``EnterPassContext()`` is executed in the order of ``instruments`` passed to the ``PassContext``. + * When an exception raises, ``PassContext`` disable the pass instrumentation + by clearing all registered ``PassInstrument`` instances. + * Then ``PassContext`` execute ``ExitPassContext()`` method of each ``PassInstrument`` + instances which successfully finished ``EnterPassContext()`` + * For example, if ``PassInstrument`` A, B, and C are registered to a ``PassContext`` + and A finished ``EnterPassContext()`` while B throws an exception, then C + is never executed; ``ExitPassContext()`` of A is executed. + +- ``InstrumentExitPassContext`` + + * ``ExitPassContext()`` of each ``PassInstrument`` instances are executed in + the order of ``instruments`` passed to the ``PassContext``. + * While an exception occurs, ``instruments`` is cleared. + * That means, instances registered after the one throwing exceptions do not execute ``ExitPassContext``. + +- ``InstrumentBeforePass`` + + * ``ShouldRun`` callbakc is executed if the pass is not listed as a required pass. + If the pass is a required pass, ``ShouldRun`` will not be executed for that pass. + * ``RunBeforePass`` is executed in the order of ``instruments`` if the pass is not blocked by ``ShouldRun``. + * Note that ``InstrumentBeforePass`` returns a boolean indicating whether or not the pass should be run. + * When an exception occur, it is thrown immediately. + We rely on Python Context Manager to exit ``PassContext`` safely + (meaning ``ExitPassContext`` of each instruments will be run. For C++, please refer to `include/tvm/support/with.h`_.) + +- ``InstrumentAfterPass`` + + * ``RunAfterPass`` is executed in the order of ``instruments`` passed to the ``PassContext``. + * When an exception occur, it is thrown immediately. + We rely on Python Context Manager or ``With`` class(`include/tvm/support/with.h`_) to exit ``PassContext`` safely + + Python Frontend ~~~~~~~~~~~~~~~ @@ -598,20 +649,21 @@ A customizable framework to instrument passes is provided. ``PassInstrument`` cl ): # ... -One can implement a ``PassInstrument`` by ``pass_instrument`` decorator(`python/tvm/ir/instrument.py`_) with a class implementing following methods: +One can implement a ``PassInstrument`` by using the ``pass_instrument`` decorator(`python/tvm/ir/instrument.py`_) on a class implementing following methods: - ``enter_pass_ctx`` - * This callback is run at the moement of entering ``PassContext``. + * This callback is run when entering ``PassContext``. - ``exit_pass_ctx`` - * This callback is run at the moement of exiting ``PassContext``. + * This callback is run when exiting ``PassContext``. - ``should_run`` - * This callback is run before a pass is executed, returning a boolean indicating if the pass should be run. - * If a pass is listed as required, this callback will not be executed for that pass. + * This callback is run before a pass is executed. It returns a boolean + indicating whether or not the pass should be run. + * If a pass is listed as required, ``should_run`` will not have effect and not be executed. - ``run_before_pass`` @@ -630,6 +682,8 @@ One can implement a ``PassInstrument`` by ``pass_instrument`` decorator(`python/ .. _include/tvm/ir/transform.h: https://github.com/apache/tvm/blob/main/include/tvm/ir/transform.h +.. _include/tvm/support/with.h: https://github.com/apache/tvm/blob/main/include/tvm/support/with.h + .. _src/relay/ir/transform.cc: https://github.com/apache/tvm/blob/main/src/relay/ir/transform.cc .. _src/ir/transform.cc: https://github.com/apache/tvm/blob/main/src/ir/transform.cc @@ -647,4 +701,5 @@ One can implement a ``PassInstrument`` by ``pass_instrument`` decorator(`python/ .. _src/tir/transforms/unroll_loop.cc: https://github.com/apache/tvm/blob/main/src/tir/transforms/unroll_loop.cc .. _use pass infra: https://github.com/apache/tvm/blob/main/tutorials/dev/use_pass_infra.py + .. _use pass instrument: https://github.com/apache/tvm/blob/main/tutorials/dev/use_pass_instrument.py diff --git a/python/tvm/ir/instrument.py b/python/tvm/ir/instrument.py index ab2fcb9591cc..23666c1e35d0 100644 --- a/python/tvm/ir/instrument.py +++ b/python/tvm/ir/instrument.py @@ -93,7 +93,7 @@ def pass_instrument(pi_cls=None): Examples -------- - The following code block show how to decorate a pass instrument class. + The following code block shows how to decorate a pass instrument class. .. code-block:: python @@ -142,7 +142,8 @@ def create_pass_instrument(pi_cls): @tvm._ffi.register_object("instrument.PassInstrument") class PassTimingInstrument(tvm.runtime.Object): - """A wrapper to create a passes time instrument that implemented in C++""" + """A wrapper to create a passes time instrument that implemented in C++ + """ def __init__(self): self.__init_handle_by_constructor__(_ffi_instrument_api.MakePassTimingInstrument) @@ -154,5 +155,19 @@ def render(): ------- string : string The rendered string result of time profiles + + Examples + -------- + + The following code-block shows how to use this instrument. + + .. code-block:: python + + timing_inst = PassTimingInstrument() + with tvm.transform.PassContext(instruments=[timing_inst]): + relay_mod = relay.transform.InferType()(relay_mod) + relay_mod = relay.transform.FoldScaleAxis()(relay_mod) + # before exiting the context, get profile results. + profiles = timing_inst.render() """ return _ffi_instrument_api.RenderTimePassProfiles() diff --git a/python/tvm/ir/transform.py b/python/tvm/ir/transform.py index 3a3ac16be677..eb31d58b4428 100644 --- a/python/tvm/ir/transform.py +++ b/python/tvm/ir/transform.py @@ -107,8 +107,8 @@ def __exit__(self, ptype, value, trace): def override_instruments(self, instruments): """Override instruments within this PassContext. - If there are existing instruments, their exit_pass_ctx callbacks are called. - Then switching to new instruments and calling new enter_pass_ctx callbacks. + If there are existing instruments, their ``exit_pass_ctx`` callbacks are called. + Then switching to new instruments and calling new ``enter_pass_ctx`` callbacks. instruments : Sequence[PassInstrument] The list of pass instrument implementations. diff --git a/tests/python/relay/test_pass_instrument.py b/tests/python/relay/test_pass_instrument.py index 86283fd31819..4122d4fad9e1 100644 --- a/tests/python/relay/test_pass_instrument.py +++ b/tests/python/relay/test_pass_instrument.py @@ -243,9 +243,6 @@ def __init__(self, id): def exit_pass_ctx(self): events.append(self.id + " exit ctx") - def exit_pass_ctx(self): - events.append(self.id + " exit ctx") - @pass_instrument class PIBroken(PI): def __init__(self, id): diff --git a/tutorials/dev/use_pass_infra.py b/tutorials/dev/use_pass_infra.py index 3804b1496d05..d072a4b3ac98 100644 --- a/tutorials/dev/use_pass_infra.py +++ b/tutorials/dev/use_pass_infra.py @@ -261,17 +261,18 @@ def visit_constant(self, c): ] ) +############################################################################### # By inserting the ``PrintIR`` pass after ``FoldConstant``, the pass infra will # dump out the module IR when ``FoldConstant`` is done. Users can plug in this # pass after any pass they want to debug for viewing the optimization effect. # -# There is a more flexible debugging mechanism also exposed by the build configuration -# object. One can pass a tracing function which can be used to execute arbitrary code -# before and/or after each pass. A tracing function will receive a :py::class:`tvm.IRModule`, -# a :py:class:`tvm.transform.PassInfo` object, -# and a boolean indicating whether you are executing before, or after a pass. -# An example is below. - +# There is a more flexible debugging mechanism. One can implement a ``PassInstrument`` +# class to execute arbitrary code not only before and/or after each pass but also +# at entering/exiting ``PassContext``. See :ref:`pass_instrument_section_tag` +# for more details. +# +# Here we use :py::func`tvm.instrument.pass_instrument` decorator to implement +# a PassInsturment class printing IR before execution of each passes: @tvm.instrument.pass_instrument class PrintIR: diff --git a/tutorials/dev/use_pass_instrument.py b/tutorials/dev/use_pass_instrument.py index 4b749863e9a4..6d8e2695647d 100644 --- a/tutorials/dev/use_pass_instrument.py +++ b/tutorials/dev/use_pass_instrument.py @@ -22,16 +22,19 @@ ============================== **Author**: `Chi-Wei Wang `_ -As more and more passes are implemented, it becomes interesting to instrument +As more and more passes are implemented, it becomes useful to instrument passes execution, analyze per-pass effects and observe various events. -We have extended :py:class:`tvm.transform.PassContext` to accept a list of -instrument classes. Also a decorator :py:func:`tvm.ir.instrument.pass_instrument` is provided to easily implement instrument classes. +Pass infrastructure provides instrument mechanism. One can pass a list of +instrument instances to :py:class:`tvm.transform.PassContext`. +Also a decorator :py:func:`tvm.instrument.pass_instrument` is provided +to easily implement instrument classes. This tutorial demostrates how developers can use ``PassContext`` to instrument -passes. For more details, please refer to the :ref:`pass-infra` +passes. Please also refer to the :ref:`pass-infra`. """ import tvm -from tvm import relay +import tvm.relay as relay +from tvm.relay.testing import resnet from tvm.contrib.download import download_testdata from tvm.relay.build_module import bind_params_by_name from tvm.ir.instrument import ( @@ -39,24 +42,19 @@ pass_instrument, ) + ############################################################################### # Create An Example Relay Program # ------------------------------- -# We create a Relay program from a Pytorch model. -# Here we pick up ``mobilenet_v2`` from torchvision. -import torch -import torchvision - -model_name = "mobilenet_v2" -model = getattr(torchvision.models, model_name)(pretrained=True) -model = model.eval() - -input_shape = [1, 3, 224, 224] -input_data = torch.randn(input_shape) -scripted_model = torch.jit.trace(model, input_data).eval() - -shape_list = [("input0", input_shape)] -relay_mod, relay_params = relay.frontend.from_pytorch(scripted_model, shape_list) +# We use pre-defined resnet-18 network in Relay. +batch_size = 1 +num_of_image_class = 1000 +image_shape = (3, 224, 224) +output_shape = (batch_size, num_of_image_class) +relay_mod, relay_params = resnet.get_workload( + num_layers=18, batch_size=1, image_shape=image_shape +) +print(relay_mod.astext(show_meta_data=False)) ############################################################################### @@ -74,11 +72,36 @@ print(profiles) +############################################################################### +# One can also use the current ``PassContext`` and register +# ``PassInstrument`` instances by ``override_instruments`` method. +# Note that ``override_instruments`` executes ``exit_pass_ctx`` callbacks +# if any instrument already exists. Then it switches to new instruments +# and calls ``enter_pass_ctx`` callbacks of new instruments. +# Refer to following sections and :py:func:`tvm.instrument.pass_instrument` for these callbacks. +cur_pass_ctx = tvm.transform.PassContext.current() +cur_pass_ctx.override_instruments([timing_inst]) +relay_mod = relay.transform.InferType()(relay_mod) +relay_mod = relay.transform.FoldScaleAxis()(relay_mod) +profiles = timing_inst.render() +print(profiles) + + +############################################################################### +# Register empty list to clear instruments. +# +# Note that ``exit_pass_ctx`` of ``PassTimingInstrument`` is called. +# Profiles are cleared so nothing is printed. +cur_pass_ctx.override_instruments([]) +profiles = timing_inst.render() +print(profiles) + + ############################################################################### # Create Customized Instrument Class # ---------------------------------- # A customized instrument class can be easily created by -# :py:func:`tvm.ir.instrument.pass_instrument` decorator. +# :py:func:`tvm.instrument.pass_instrument` decorator. # # Let's create an instrument class which calculate the difference of ``CallNode`` # counting per ``op.name`` before and after passes. @@ -198,9 +221,10 @@ def _diff(d_after, d_before): profiles = timing_inst.render() # Uncomment the next line to see timing-profile results. # print(profiles) -# + + +############################################################################### # We can see how many CallNode increase/decrease per op type. -# from pprint import pprint pprint(call_node_inst.get_pass_to_op_diff()) From 66596a9d18f06c4123e4f6d0c19370d721f0decd Mon Sep 17 00:00:00 2001 From: chiwwang Date: Fri, 11 Jun 2021 22:09:21 +0800 Subject: [PATCH 09/18] Fix linting --- docs/conf.py | 2 +- python/tvm/ir/instrument.py | 3 +-- tutorials/dev/use_pass_infra.py | 1 + tutorials/dev/use_pass_instrument.py | 4 +--- 4 files changed, 4 insertions(+), 6 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index f2d3ef9e88fa..d279c11b6c13 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -277,7 +277,7 @@ def git_describe_version(original_version): "low_level_custom_pass.py", "use_pass_infra.py", "bring_your_own_datatypes.py", - "use_pass_instrument.py" + "use_pass_instrument.py", ], } diff --git a/python/tvm/ir/instrument.py b/python/tvm/ir/instrument.py index 23666c1e35d0..25038310bfaf 100644 --- a/python/tvm/ir/instrument.py +++ b/python/tvm/ir/instrument.py @@ -142,8 +142,7 @@ def create_pass_instrument(pi_cls): @tvm._ffi.register_object("instrument.PassInstrument") class PassTimingInstrument(tvm.runtime.Object): - """A wrapper to create a passes time instrument that implemented in C++ - """ + """A wrapper to create a passes time instrument that implemented in C++""" def __init__(self): self.__init_handle_by_constructor__(_ffi_instrument_api.MakePassTimingInstrument) diff --git a/tutorials/dev/use_pass_infra.py b/tutorials/dev/use_pass_infra.py index d072a4b3ac98..63d22b1df2bc 100644 --- a/tutorials/dev/use_pass_infra.py +++ b/tutorials/dev/use_pass_infra.py @@ -274,6 +274,7 @@ def visit_constant(self, c): # Here we use :py::func`tvm.instrument.pass_instrument` decorator to implement # a PassInsturment class printing IR before execution of each passes: + @tvm.instrument.pass_instrument class PrintIR: """Print the name of the pass, the IR, only before passes execute.""" diff --git a/tutorials/dev/use_pass_instrument.py b/tutorials/dev/use_pass_instrument.py index 6d8e2695647d..ed79d52c206c 100644 --- a/tutorials/dev/use_pass_instrument.py +++ b/tutorials/dev/use_pass_instrument.py @@ -51,9 +51,7 @@ num_of_image_class = 1000 image_shape = (3, 224, 224) output_shape = (batch_size, num_of_image_class) -relay_mod, relay_params = resnet.get_workload( - num_layers=18, batch_size=1, image_shape=image_shape -) +relay_mod, relay_params = resnet.get_workload(num_layers=18, batch_size=1, image_shape=image_shape) print(relay_mod.astext(show_meta_data=False)) From 98d4fbfba98cf63e51e3d1269404906d49480416 Mon Sep 17 00:00:00 2001 From: chiwwang Date: Fri, 11 Jun 2021 22:24:38 +0800 Subject: [PATCH 10/18] Fix the order of tutorial. --- docs/conf.py | 2 +- tutorials/dev/use_pass_instrument.py | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index d279c11b6c13..83fa8fd37ae9 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -276,8 +276,8 @@ def git_describe_version(original_version): "dev": [ "low_level_custom_pass.py", "use_pass_infra.py", - "bring_your_own_datatypes.py", "use_pass_instrument.py", + "bring_your_own_datatypes.py", ], } diff --git a/tutorials/dev/use_pass_instrument.py b/tutorials/dev/use_pass_instrument.py index ed79d52c206c..d487c8723180 100644 --- a/tutorials/dev/use_pass_instrument.py +++ b/tutorials/dev/use_pass_instrument.py @@ -91,8 +91,9 @@ # Note that ``exit_pass_ctx`` of ``PassTimingInstrument`` is called. # Profiles are cleared so nothing is printed. cur_pass_ctx.override_instruments([]) -profiles = timing_inst.render() -print(profiles) +# Uncomment the call to .render() to see a warning like: +# Warning: no passes have been profiled, did you enable pass profiling? +# profiles = timing_inst.render() ############################################################################### From 229991be46deba044ad06dc37f4b9f9260684984 Mon Sep 17 00:00:00 2001 From: chiwwang Date: Sun, 13 Jun 2021 17:16:18 +0800 Subject: [PATCH 11/18] Add exception handling. Address feedbacks. --- docs/dev/pass_infra.rst | 137 +++++++++++++++++------- tutorials/dev/use_pass_instrument.py | 151 ++++++++++++++++++++++++++- 2 files changed, 248 insertions(+), 40 deletions(-) diff --git a/docs/dev/pass_infra.rst b/docs/dev/pass_infra.rst index b2a5923f3e88..e51749ff9156 100644 --- a/docs/dev/pass_infra.rst +++ b/docs/dev/pass_infra.rst @@ -110,7 +110,7 @@ required/disabled passes, etc. For instance, we may have a configuration which performs all passes at ``opt_level=3`` with some disabled passes using ``disabled_pass=xx`` provided by ``PassContext``. Now we could glob all passes at ``opt_level=3`` and exclude those in the disabled pass list. ``PassContext`` -also provides a way to instrument all passes. See section :ref:`pass_instrument_section_tag`. +also provides a way to instrument all passes. See section :ref:`pass_instrument_cpp_backend`. This class is designed for users to conveniently write the Python ``with`` syntax to perform optimizations under a certain configuration. In addition, the @@ -396,14 +396,45 @@ To allow other C++ modules to apply this pass, we declare a free function in TVM_DLL Pass FoldConstant(); -.. _pass_instrument_section_tag: +.. _pass_instrument_cpp_backend: Pass Instrument ~~~~~~~~~~~~~~~ -``PassInstrument`` provides callbacks run when entering/exiting ``PassContext`` and before/after executing passes. -Multiple ``PassInstrument`` instances can be registed into a single ``PassContext``. -Instrument instances are called sequentially in the order of ``instruments`` argument passed to ``PassContext``. +Currently we introduce four instrument point in the life-cycle of ``PassContext``. + +.. code:: c++ + + TVM_DLL void InstrumentEnterPassContext(); + TVM_DLL void InstrumentExitPassContext(); + TVM_DLL bool InstrumentBeforePass(const IRModule& mod, const PassInfo& info) const; + TVM_DLL void InstrumentAfterPass(const IRModule& mod, const PassInfo& info) const; + +``InstrumentEnterPassContext`` is called immediately when entering the scope +of the ``PassContext`` instance. + +``InstrumentExitPassContext`` is called when leaving the scope of ``PassContext``, +or exceptions occur during the execution of passes. +This method is also called when instruments is being overriden by ``override_instruments`` in :py:class:`tvm.transform.PassContext`. +See :ref:`pass_instrument_overriden`. + +``InstrumentBeforePass`` is called before execution. +``InstrumentAfterPass`` is called after executioon if the pass should be run. The behavior is like: + +.. code:: c++ + + if (pass_ctx.InstrumentBeforePass(ir_module, pass_info)) { + new_ir_module = run_pass(ir_module, pass_ctx); + pass_ctx.InstrumentAfterPass(new_ir_module, pass_info); + return new_ir_module; + } + +The ``PassInstrument`` interface allow you to run arbitrary code inside above four methods. +Multiple ``PassInstrument`` instances can be registed into a single +``PassContext``. ``PassInstrument`` instances are called sequentially in the order of +``instruments`` argument passed to ``PassContext``. + +``PassInstrument`` provides following interfaces: .. code:: c++ @@ -427,36 +458,29 @@ Instrument instances are called sequentially in the order of ``instruments`` arg } // namespace instrument -Python interfaces are provided to implement ``PassInstrument`` quickly. - -Following four methods are invoked in the life-cycle of ``PassContext``. +Python frontend are provided to implement ``PassInstrument`` quickly. See :ref:`pass_instrument_py_frontend`. -.. code:: c++ - - TVM_DLL void InstrumentEnterPassContext(); - TVM_DLL void InstrumentExitPassContext(); - TVM_DLL bool InstrumentBeforePass(const IRModule& mod, const PassInfo& info) const; - TVM_DLL void InstrumentAfterPass(const IRModule& mod, const PassInfo& info) const; +Within a ``PassContext``, the call sequence of a ``PassInstrument`` instance is like: -``InstrumentEnterPassContext`` is called immediately when the scope -of the ``PassContext`` instance is entered. +:: -``InstrumentExitPassContext`` is called when the scope of ``PassContextNode`` -is being leaved, or exceptions occur during the execution of passes. -This method is also called when instruments is being overriden by ``override_instruments`` in ::py:class:`tvm.transform.PassContext`. + with PassContext(instruments=[pi]) # pi = a PassInstrument implementation. + pi.EnterPassContext() -``InstrumentBeforePass`` is called before pass-execution. -``InstrumentAfterPass`` is called after pass-executioon if the pass should be run. The behavir is like: + if pi.ShouldRun(Pass1): + pi.RunBeforePass() + Pass1() + pi.RunAfterPass() -.. code:: c++ + if pi.ShouldRun(Pass2): + pi.RunBeforePass() + Pass2() + pi.RunAfterPass() - if (pass_ctx.InstrumentBeforePass(ir_module, pass_info)) { - new_ir_module = run_pass(ir_module, pass_ctx); - pass_ctx.InstrumentAfterPass(new_ir_module, pass_info); - return new_ir_module; - } + pi.ExitPassContext() -Here is a brief introduction of each methods. See (`src/ir/transform.cc`_) for more details. +Here is a brief introduction of relations between ``PassInstrument`` interfaces +and ``PassContext`` methods. See (`src/ir/transform.cc`_) for more details. - ``InstrumentEnterPassContext`` @@ -474,12 +498,11 @@ Here is a brief introduction of each methods. See (`src/ir/transform.cc`_) for m * ``ExitPassContext()`` of each ``PassInstrument`` instances are executed in the order of ``instruments`` passed to the ``PassContext``. * While an exception occurs, ``instruments`` is cleared. - * That means, instances registered after the one throwing exceptions do not execute ``ExitPassContext``. + * ``PassInstrument`` Instances registered after the one throwing exceptions do not execute ``ExitPassContext``. - ``InstrumentBeforePass`` - * ``ShouldRun`` callbakc is executed if the pass is not listed as a required pass. - If the pass is a required pass, ``ShouldRun`` will not be executed for that pass. + * ``ShouldRun`` is executed if the pass is not listed as a required pass. * ``RunBeforePass`` is executed in the order of ``instruments`` if the pass is not blocked by ``ShouldRun``. * Note that ``InstrumentBeforePass`` returns a boolean indicating whether or not the pass should be run. * When an exception occur, it is thrown immediately. @@ -492,6 +515,17 @@ Here is a brief introduction of each methods. See (`src/ir/transform.cc`_) for m * When an exception occur, it is thrown immediately. We rely on Python Context Manager or ``With`` class(`include/tvm/support/with.h`_) to exit ``PassContext`` safely +Built-in Instrument +^^^^^^^^^^^^^^^^^^^ + +There are several built-in instruments. Those marked with *TODO* are not implemented yet. + +PassTimmingInstrument (see `src/ir/instrument.cc`_) + +PrintBefore(TODO) + +PrintAfter(TODO) + Python Frontend ~~~~~~~~~~~~~~~ @@ -630,6 +664,9 @@ decorators and then invoke it. For more examples about how to customize your own optimization pipeline and debug Relay and tir passes, please refer to the `use pass infra`_ tutorial. + +.. _pass_instrument_py_frontend: + Pass Instrument ^^^^^^^^^^^^^^^ @@ -653,29 +690,53 @@ One can implement a ``PassInstrument`` by using the ``pass_instrument`` decorato - ``enter_pass_ctx`` - * This callback is run when entering ``PassContext``. + * This method is run when entering ``PassContext``. - ``exit_pass_ctx`` - * This callback is run when exiting ``PassContext``. + * This method is run when exiting ``PassContext``. - ``should_run`` - * This callback is run before a pass is executed. It returns a boolean + * This method is run before a pass is executed. It returns a boolean indicating whether or not the pass should be run. * If a pass is listed as required, ``should_run`` will not have effect and not be executed. - ``run_before_pass`` - * If a pass should be run, this callback is run just before pass execution. + * If a pass should be run, this method is run just before pass execution. - ``run_after_pass`` - * This callback is run right after a pass has been executed. - + * This method is run right after a pass has been executed. `use pass instrument`_ tutorial provides examples for how to implement ``PassInstrument`` with Python APIs. +.. _pass_instrument_overriden: + +Override Instruments in Current PassContext +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +``override_instruments`` method is provided to override the ``instruments`` of current ``PassContext``. +For example, if passes are run without explicitly creating a new ``PassContext``, +one can still register ``PassInstrument`` into the global ``PassContext`` by: + +.. code:: python + # Get current PassContext + cur_pass_ctx = tvm.transform.PassContext.current() + # Register new PassInstrument instance + cur_pass_ctx.override_instruments([pass_inst0, pass_inst1]) + # Run Passes + mod = Pass1(mod) + mod = Pass2(mod) + # Get instrument results...e.t.c. + result0 = pass_inst0.get_result() + result1 = pass_inst1.get_result() + +Note that when ``override_instruments`` is called, the ``exit_pass_ctx`` method of +old ``PassInstrument`` instances are called. Then the ``enter_pass_ctx`` method of +new ``PassInstrument`` are called. + .. _Sequential: https://pytorch.org/docs/stable/nn.html?highlight=sequential#torch.nn.Sequential .. _Block: https://mxnet.apache.org/api/python/docs/api/gluon/block.html#gluon-block @@ -688,6 +749,8 @@ One can implement a ``PassInstrument`` by using the ``pass_instrument`` decorato .. _src/ir/transform.cc: https://github.com/apache/tvm/blob/main/src/ir/transform.cc +.. _src/ir/instrument.cc: https://github.com/apache/tvm/blob/main/src/ir/instrument.cc + .. _src/relay/transforms/fold_constant.cc: https://github.com/apache/tvm/blob/main/src/relay/transforms/fold_constant.cc .. _python/tvm/relay/transform/transform.py: https://github.com/apache/tvm/blob/main/python/tvm/relay/transform/transform.py diff --git a/tutorials/dev/use_pass_instrument.py b/tutorials/dev/use_pass_instrument.py index d487c8723180..f889846c6a94 100644 --- a/tutorials/dev/use_pass_instrument.py +++ b/tutorials/dev/use_pass_instrument.py @@ -71,12 +71,14 @@ ############################################################################### +# Use Current PassContext With Instruments +# ---------------------------------------- # One can also use the current ``PassContext`` and register # ``PassInstrument`` instances by ``override_instruments`` method. -# Note that ``override_instruments`` executes ``exit_pass_ctx`` callbacks +# Note that ``override_instruments`` executes ``exit_pass_ctx`` method # if any instrument already exists. Then it switches to new instruments -# and calls ``enter_pass_ctx`` callbacks of new instruments. -# Refer to following sections and :py:func:`tvm.instrument.pass_instrument` for these callbacks. +# and calls ``enter_pass_ctx`` method of new instruments. +# Refer to following sections and :py:func:`tvm.instrument.pass_instrument` for these methods. cur_pass_ctx = tvm.transform.PassContext.current() cur_pass_ctx.override_instruments([timing_inst]) relay_mod = relay.transform.InferType()(relay_mod) @@ -227,3 +229,146 @@ def _diff(d_after, d_before): from pprint import pprint pprint(call_node_inst.get_pass_to_op_diff()) + + +############################################################################### +# Exception Handling +# ------------------ +# Let's see what happen if exceptions occur in each methods of a ``PassInstrument``. +# +# Define ``PassInstrument`` classes which raise exceptions in enter/exit ``PassContext``: +class PassExampleBase: + def __init__(self, name): + self._name = name + + def enter_pass_ctx(self): + print(self._name, "enter_pass_ctx") + + def exit_pass_ctx(self): + print(self._name, "exit_pass_ctx") + + def should_run(self, mod, info): + print(self._name, "should_run") + return True + + def run_before_pass(self, mod, pass_info): + print(self._name, "run_before_pass") + + def run_after_pass(self, mod, pass_info): + print(self._name, "run_after_pass") + + +@pass_instrument +class PassFine(PassExampleBase): + pass + + +@pass_instrument +class PassBadEnterCtx(PassExampleBase): + def enter_pass_ctx(self): + print(self._name, " bad enter_pass_ctx!!!") + raise ValueError("{} bad enter_pass_ctx".format(self._name)) + + +@pass_instrument +class PassBadExitCtx(PassExampleBase): + def exit_pass_ctx(self): + print(self._name, "bad exit_pass_ctx!!!") + raise ValueError("{} bad exit_pass_ctx".format(self._name)) + + +############################################################################### +# If an exception occur in ``enter_pass_ctx``, ``PassContext`` disable the pass +# instrumentation. And it will run ``exit_pass_ctx`` of each ``PassInstrument`` +# which successfully finished ``enter_pass_ctx``. +# +# In following example, we can see ``exit_pass_ctx`` of `PassFine_0` is executed after exception. +demo_ctx = tvm.transform.PassContext( + instruments=[ + PassFine("PassFine_0"), + PassBadEnterCtx("PassBadEnterCtx"), + PassFine("PassFine_1"), + ] +) +try: + with demo_ctx: + relay_mod = relay.transform.InferType()(relay_mod) +except ValueError as ex: + print("Catching", str(ex).split("\n")[-1]) + +############################################################################### +# Also, all ``PassInstrument`` are cleared. +# So nothing printed while ``override_instruments`` is called. +demo_ctx.override_instruments([]) # no PassFine_0 exit_pass_ctx printed....etc + +############################################################################### +# If an exception occur in ``exit_pass_ctx``, pass instrumentation is disabled. +# Then exception is thrown. That means ``PassInstrument`` registered +# after the one throwing the exception do not execute ``exit_pass_ctx``. +demo_ctx = tvm.transform.PassContext( + instruments=[ + PassFine("PassFine_0"), + PassBadExitCtx("PassBadExitCtx"), + PassFine("PassFine_1"), + ] +) +try: + # PassFine_1 execute enter_pass_ctx, but not exit_pass_ctx. + with demo_ctx: + relay_mod = relay.transform.InferType()(relay_mod) +except ValueError as ex: + print("Catching", str(ex).split("\n")[-1]) + +############################################################################### +# Exceptions occured in ``should_run``, ``run_before_pass``, ``run_after_pass`` +# are not handled explitcitly -- that means, we rely on the context manager +# (the ``with`` syntax) to exit ``PassContext`` safely. +# +# We use ``run_before_pass`` as an example: +@pass_instrument +class PassBadRunBefore(PassExampleBase): + def run_before_pass(self, mod, pass_info): + print(self._name, "bad run_before_pass!!!") + raise ValueError("{} bad run_before_pass".format(self._name)) + + +demo_ctx = tvm.transform.PassContext( + instruments=[ + PassFine("PassFine_0"), + PassBadRunBefore("PassBadRunBefore"), + PassFine("PassFine_1"), + ] +) +try: + # All exit_pass_ctx are called. + with demo_ctx: + relay_mod = relay.transform.InferType()(relay_mod) +except ValueError as ex: + print("Catching", str(ex).split("\n")[-1]) + +############################################################################### +# Also note that pass instrumentation is not disable. So if we call +# ``override_instruments``, the ``exit_pass_ctx`` of old registered ``PassInstrument`` +# is called. +demo_ctx.override_instruments([]) + +############################################################################### +# If we don't wrap pass execution with ``with`` syntax, ``exit_pass_ctx`` is not +# called. Let try this with current ``PassContext``: +cur_pass_ctx = tvm.transform.PassContext.current() +cur_pass_ctx.override_instruments( + [ + PassFine("PassFine_0"), + PassBadRunBefore("PassBadRunBefore"), + PassFine("PassFine_1"), + ] +) + +############################################################################### +# Then call passes. ``exit_pass_ctx`` is not executed after the exception, +# as expectation. +try: + # No ``exit_pass_ctx`` got executed. + relay_mod = relay.transform.InferType()(relay_mod) +except ValueError as ex: + print("Catching", str(ex).split("\n")[-1]) From 80aaa37f1cafba1b31a3cee4b0ff1f93b1dba9b3 Mon Sep 17 00:00:00 2001 From: chiwwang Date: Mon, 14 Jun 2021 16:30:13 +0800 Subject: [PATCH 12/18] Fix CI error -- clearing instruments in global pass_ctx --- tutorials/dev/use_pass_instrument.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tutorials/dev/use_pass_instrument.py b/tutorials/dev/use_pass_instrument.py index f889846c6a94..544bb18f97a7 100644 --- a/tutorials/dev/use_pass_instrument.py +++ b/tutorials/dev/use_pass_instrument.py @@ -266,7 +266,7 @@ class PassFine(PassExampleBase): @pass_instrument class PassBadEnterCtx(PassExampleBase): def enter_pass_ctx(self): - print(self._name, " bad enter_pass_ctx!!!") + print(self._name, "bad enter_pass_ctx!!!") raise ValueError("{} bad enter_pass_ctx".format(self._name)) @@ -321,7 +321,7 @@ def exit_pass_ctx(self): ############################################################################### # Exceptions occured in ``should_run``, ``run_before_pass``, ``run_after_pass`` -# are not handled explitcitly -- that means, we rely on the context manager +# are not handled explicitly -- that means, we rely on the context manager # (the ``with`` syntax) to exit ``PassContext`` safely. # # We use ``run_before_pass`` as an example: @@ -372,3 +372,7 @@ def run_before_pass(self, mod, pass_info): relay_mod = relay.transform.InferType()(relay_mod) except ValueError as ex: print("Catching", str(ex).split("\n")[-1]) + +############################################################################### +# Clear instruments. +cur_pass_ctx.override_instruments([]) From b6083aa1d0f8a117a07bf7e6adbab14c4c3f67c2 Mon Sep 17 00:00:00 2001 From: chiwwang Date: Mon, 14 Jun 2021 22:15:58 +0800 Subject: [PATCH 13/18] Clarify section hierachy. --- docs/dev/pass_infra.rst | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/docs/dev/pass_infra.rst b/docs/dev/pass_infra.rst index e51749ff9156..ab27febb326c 100644 --- a/docs/dev/pass_infra.rst +++ b/docs/dev/pass_infra.rst @@ -345,7 +345,7 @@ favorably use Python APIs to create a specific pass object. Pass Sequential(tvm::Array passes, PassInfo pass_info); Pass Registration -~~~~~~~~~~~~~~~~~ +^^^^^^^^^^^^^^^^^ We've covered the concept of different level of passes and the context used for compilation. It would be interesting to see how easily users can register @@ -399,7 +399,7 @@ To allow other C++ modules to apply this pass, we declare a free function in .. _pass_instrument_cpp_backend: Pass Instrument -~~~~~~~~~~~~~~~ +^^^^^^^^^^^^^^^ Currently we introduce four instrument point in the life-cycle of ``PassContext``. @@ -526,7 +526,6 @@ PrintBefore(TODO) PrintAfter(TODO) - Python Frontend ~~~~~~~~~~~~~~~ From 2c201c1844dbe7f22d37b3ff14444837839d1680 Mon Sep 17 00:00:00 2001 From: chiwwang Date: Tue, 15 Jun 2021 15:31:45 +0800 Subject: [PATCH 14/18] Emphasize to use decorator instead of subclassing --- docs/dev/pass_infra.rst | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/docs/dev/pass_infra.rst b/docs/dev/pass_infra.rst index ab27febb326c..e6d636a30bb7 100644 --- a/docs/dev/pass_infra.rst +++ b/docs/dev/pass_infra.rst @@ -401,7 +401,7 @@ To allow other C++ modules to apply this pass, we declare a free function in Pass Instrument ^^^^^^^^^^^^^^^ -Currently we introduce four instrument point in the life-cycle of ``PassContext``. +Currently we introduce four instrument points in the life-cycle of ``PassContext``. .. code:: c++ @@ -685,7 +685,10 @@ A customizable framework to instrument passes is provided. ``PassInstrument`` cl ): # ... -One can implement a ``PassInstrument`` by using the ``pass_instrument`` decorator(`python/tvm/ir/instrument.py`_) on a class implementing following methods: +One can implement a ``PassInstrument`` by using the ``pass_instrument`` +decorator(`python/tvm/ir/instrument.py`_) on a class implementing following methods. +Note that it is recommended to use the ``pass_instrument`` decorator to implement +``PassInstrument``, instead of overriding or subclassing. - ``enter_pass_ctx`` From b0a270f47bbfaa5f71e5d677281766839ccf2c08 Mon Sep 17 00:00:00 2001 From: chiwwang Date: Mon, 28 Jun 2021 10:58:17 +0800 Subject: [PATCH 15/18] Add a sentence to explain Pass Instrument. Fix typo. --- docs/dev/pass_infra.rst | 22 +++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/docs/dev/pass_infra.rst b/docs/dev/pass_infra.rst index e6d636a30bb7..7ba0c3264b06 100644 --- a/docs/dev/pass_infra.rst +++ b/docs/dev/pass_infra.rst @@ -401,7 +401,11 @@ To allow other C++ modules to apply this pass, we declare a free function in Pass Instrument ^^^^^^^^^^^^^^^ -Currently we introduce four instrument points in the life-cycle of ``PassContext``. +Pass Instrument is a mechanism to analyze the pass itself. For example, +we can use the infrastructure to know how much time and memory a pass requires +or how a pass can transform the IR module. + +We introduce four instrument points in the life-cycle of ``PassContext``. .. code:: c++ @@ -419,7 +423,7 @@ This method is also called when instruments is being overriden by ``override_ins See :ref:`pass_instrument_overriden`. ``InstrumentBeforePass`` is called before execution. -``InstrumentAfterPass`` is called after executioon if the pass should be run. The behavior is like: +``InstrumentAfterPass`` is called after execution if the pass should be run. The behavior is like: .. code:: c++ @@ -520,11 +524,19 @@ Built-in Instrument There are several built-in instruments. Those marked with *TODO* are not implemented yet. -PassTimmingInstrument (see `src/ir/instrument.cc`_) +- PassTimingInstrument (see `src/ir/instrument.cc`_) + + * Profile the execution time of passes. + +- PrintIRBefore(TODO) + + * Print the IR module before the pass transforms it. :py:func:`tvm.transform.PrintIR` + can also serve this purpose if we insert it around passes. However, + with the ``PassInstrument``, we don't need to modify the sequence of passes. -PrintBefore(TODO) +- PrintAfter(TODO) -PrintAfter(TODO) + * Print the IR module after the pass transforms it. Python Frontend ~~~~~~~~~~~~~~~ From b5cffefe1eb12fc3314886d8089eee31cee51532 Mon Sep 17 00:00:00 2001 From: chiwwang Date: Mon, 28 Jun 2021 11:41:17 +0800 Subject: [PATCH 16/18] Shrink python docs a little. --- docs/dev/pass_infra.rst | 36 +++++++++--------------------------- 1 file changed, 9 insertions(+), 27 deletions(-) diff --git a/docs/dev/pass_infra.rst b/docs/dev/pass_infra.rst index 7ba0c3264b06..8973679c3c55 100644 --- a/docs/dev/pass_infra.rst +++ b/docs/dev/pass_infra.rst @@ -681,22 +681,6 @@ optimization pipeline and debug Relay and tir passes, please refer to the Pass Instrument ^^^^^^^^^^^^^^^ -A customizable framework to instrument passes is provided. ``PassInstrument`` classes can be registered while constructing ``PassContext``. - -.. code:: python - - @tvm._ffi.register_object("transform.PassContext") - class PassContext(tvm.runtime.Object): - def __init__( - self, - opt_level=2, - required_pass=None, - disabled_pass=None, - instruments=None, - config=None, - ): - # ... - One can implement a ``PassInstrument`` by using the ``pass_instrument`` decorator(`python/tvm/ir/instrument.py`_) on a class implementing following methods. Note that it is recommended to use the ``pass_instrument`` decorator to implement @@ -712,9 +696,8 @@ Note that it is recommended to use the ``pass_instrument`` decorator to implemen - ``should_run`` - * This method is run before a pass is executed. It returns a boolean + * This method is run before a pass is executed, returning a boolean indicating whether or not the pass should be run. - * If a pass is listed as required, ``should_run`` will not have effect and not be executed. - ``run_before_pass`` @@ -724,6 +707,9 @@ Note that it is recommended to use the ``pass_instrument`` decorator to implemen * This method is run right after a pass has been executed. +``PassInstrument`` instances can be registered through ``instruments`` argument in +:py:class:`tvm.transform.PassContext`. + `use pass instrument`_ tutorial provides examples for how to implement ``PassInstrument`` with Python APIs. .. _pass_instrument_overriden: @@ -736,16 +722,12 @@ For example, if passes are run without explicitly creating a new ``PassContext`` one can still register ``PassInstrument`` into the global ``PassContext`` by: .. code:: python - # Get current PassContext + cur_pass_ctx = tvm.transform.PassContext.current() - # Register new PassInstrument instance - cur_pass_ctx.override_instruments([pass_inst0, pass_inst1]) - # Run Passes - mod = Pass1(mod) - mod = Pass2(mod) - # Get instrument results...e.t.c. - result0 = pass_inst0.get_result() - result1 = pass_inst1.get_result() + # override PassInstrument instances + cur_pass_ctx.override_instruments([pass_inst]) + mod = pass_seq(mod) + result = pass_inst.get_result() Note that when ``override_instruments`` is called, the ``exit_pass_ctx`` method of old ``PassInstrument`` instances are called. Then the ``enter_pass_ctx`` method of From fd326227826ce5399ec675c380f0ac08c31a4b25 Mon Sep 17 00:00:00 2001 From: chiwwang Date: Tue, 29 Jun 2021 09:23:25 +0800 Subject: [PATCH 17/18] Fix tag name. --- tutorials/dev/use_pass_infra.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tutorials/dev/use_pass_infra.py b/tutorials/dev/use_pass_infra.py index 63d22b1df2bc..468c4d40b942 100644 --- a/tutorials/dev/use_pass_infra.py +++ b/tutorials/dev/use_pass_infra.py @@ -268,7 +268,7 @@ def visit_constant(self, c): # # There is a more flexible debugging mechanism. One can implement a ``PassInstrument`` # class to execute arbitrary code not only before and/or after each pass but also -# at entering/exiting ``PassContext``. See :ref:`pass_instrument_section_tag` +# at entering/exiting ``PassContext``. See :ref:`pass_instrument_cpp_backend` # for more details. # # Here we use :py::func`tvm.instrument.pass_instrument` decorator to implement From 078db705249a21f351056810f92c8317bbc288eb Mon Sep 17 00:00:00 2001 From: chiwwang Date: Tue, 29 Jun 2021 10:10:34 +0800 Subject: [PATCH 18/18] Address feedbacks. --- python/tvm/ir/instrument.py | 3 -- tutorials/dev/use_pass_instrument.py | 80 +++++++++++++--------------- 2 files changed, 37 insertions(+), 46 deletions(-) diff --git a/python/tvm/ir/instrument.py b/python/tvm/ir/instrument.py index 25038310bfaf..1948a6787eac 100644 --- a/python/tvm/ir/instrument.py +++ b/python/tvm/ir/instrument.py @@ -93,7 +93,6 @@ def pass_instrument(pi_cls=None): Examples -------- - The following code block shows how to decorate a pass instrument class. .. code-block:: python @@ -158,8 +157,6 @@ def render(): Examples -------- - The following code-block shows how to use this instrument. - .. code-block:: python timing_inst = PassTimingInstrument() diff --git a/tutorials/dev/use_pass_instrument.py b/tutorials/dev/use_pass_instrument.py index 544bb18f97a7..3369304a651d 100644 --- a/tutorials/dev/use_pass_instrument.py +++ b/tutorials/dev/use_pass_instrument.py @@ -23,11 +23,12 @@ **Author**: `Chi-Wei Wang `_ As more and more passes are implemented, it becomes useful to instrument -passes execution, analyze per-pass effects and observe various events. -Pass infrastructure provides instrument mechanism. One can pass a list of -instrument instances to :py:class:`tvm.transform.PassContext`. -Also a decorator :py:func:`tvm.instrument.pass_instrument` is provided -to easily implement instrument classes. +pass execution, analyze per-pass effects, and observe various events. + +We can instrument passes by providing a list of :py:class:`tvm.ir.instrument.PassInstrument` +instances to :py:class:`tvm.transform.PassContext`. We provide a pass instrument +for collecting timing information (:py:class:`tvm.ir.instrument.PassTimingInstrument`), +but an extension mechanism is available via the :py:func:`tvm.instrument.pass_instrument` decorator. This tutorial demostrates how developers can use ``PassContext`` to instrument passes. Please also refer to the :ref:`pass-infra`. @@ -52,21 +53,23 @@ image_shape = (3, 224, 224) output_shape = (batch_size, num_of_image_class) relay_mod, relay_params = resnet.get_workload(num_layers=18, batch_size=1, image_shape=image_shape) +print("Printing the IR module...") print(relay_mod.astext(show_meta_data=False)) ############################################################################### # Create PassContext With Instruments # ----------------------------------- -# It is as simple as passing ``instruments`` argument to ``PassContext`` constructor. -# A built-in ``PassTimingInstrument`` is used to profile the execution time of -# each passes. +# To run all passes with an instrument, pass it via the ``instruments`` argument to +# the ``PassContext`` constructor. A built-in ``PassTimingInstrument`` is used to +# profile the execution time of each passes. timing_inst = PassTimingInstrument() with tvm.transform.PassContext(instruments=[timing_inst]): relay_mod = relay.transform.InferType()(relay_mod) relay_mod = relay.transform.FoldScaleAxis()(relay_mod) # before exiting the context, get profile results. profiles = timing_inst.render() +print("Printing results of timing profile...") print(profiles) @@ -84,11 +87,12 @@ relay_mod = relay.transform.InferType()(relay_mod) relay_mod = relay.transform.FoldScaleAxis()(relay_mod) profiles = timing_inst.render() +print("Printing results of timing profile...") print(profiles) ############################################################################### -# Register empty list to clear instruments. +# Register empty list to clear existing instruments. # # Note that ``exit_pass_ctx`` of ``PassTimingInstrument`` is called. # Profiles are cleared so nothing is printed. @@ -101,13 +105,14 @@ ############################################################################### # Create Customized Instrument Class # ---------------------------------- -# A customized instrument class can be easily created by +# A customized instrument class can be created using the # :py:func:`tvm.instrument.pass_instrument` decorator. # -# Let's create an instrument class which calculate the difference of ``CallNode`` -# counting per ``op.name`` before and after passes. +# Let's create an instrument class which calculates the change in number of +# occurrences of each operator caused by each pass. We can look at ``op.name`` to +# find the name of each operator. And we do this before and after passes to calculate the difference. + -# decorate the class @pass_instrument class RelayCallNodeDiffer: def __init__(self): @@ -149,26 +154,26 @@ def get_pass_to_op_diff(self): @staticmethod def _count_nodes(mod): + """Count the number of occurrences of each operator in the module""" ret = {} def visit(node): if isinstance(node, relay.expr.Call): - try: + if hasattr(node.op, "name"): op_name = node.op.name - except AttributeError: + else: # Some CallNode may not have 'name' such as relay.Function return - try: - ret[op_name] += 1 - except KeyError: - ret[op_name] = 1 + ret[op_name] = ret.get(op_name, 0) + 1 relay.analysis.post_order_visit(mod["main"], visit) return ret @staticmethod def _diff(d_after, d_before): - # d_after - d_before + """Calculate the difference of two dictionary along their keys. + The result is values in d_after minus values in d_before. + """ ret = {} key_after, key_before = set(d_after), set(d_before) for k in key_before & key_after: @@ -185,13 +190,7 @@ def _diff(d_after, d_before): ############################################################################### # Apply Passes and Multiple Instrument Classes # -------------------------------------------- -# Apply any pass you wish. Here :py:class:`tvm.relay.transform.ConvertLayout` -# and :py:class:`tvm.relay.transform.FoldConstant` are used. -# -# ``ConvertLayout`` might add ``layout_transform`` Op while ``FoldConstant`` can -# reduce the number of ``CallNode``. -# -# We can also use multiple instrument classes in a ``PassContext``. +# We can use multiple instrument classes in a ``PassContext``. # However, it should be noted that instrument methods are executed sequentially, # obeying the order of ``instruments`` argument. # So for instrument classes like ``PassTimingInstrument``, it is inevitable to @@ -201,11 +200,6 @@ def _diff(d_after, d_before): desired_layouts = { "nn.conv2d": ["NHWC", "HWIO"], } -# Because layout_transform may be added as a successor of Constant, -# we run FoldConstant twice. -# Though it is obvious only the FoldConstant after the ConvertLayout matter, -# we want to show how many layout_transform is added as a successor of -# Constant. pass_seq = tvm.transform.Sequential( [ relay.transform.FoldConstant(), @@ -213,7 +207,6 @@ def _diff(d_after, d_before): relay.transform.FoldConstant(), ] ) -# bind parameters to make VarNode as ConstantNode. relay_mod["main"] = bind_params_by_name(relay_mod["main"], relay_params) # timing_inst is put after call_node_inst. # So the execution time of ``call_node.inst.run_after_pass()`` is also counted. @@ -228,15 +221,16 @@ def _diff(d_after, d_before): # We can see how many CallNode increase/decrease per op type. from pprint import pprint +print("Printing the change in number of occurrences of each operator caused by each pass...") pprint(call_node_inst.get_pass_to_op_diff()) ############################################################################### # Exception Handling # ------------------ -# Let's see what happen if exceptions occur in each methods of a ``PassInstrument``. +# Let's see what happens if an exception occurs in a method of a ``PassInstrument``. # -# Define ``PassInstrument`` classes which raise exceptions in enter/exit ``PassContext``: +# Define ``PassInstrument`` classes which raise exceptions in enter/exit ``PassContext``: class PassExampleBase: def __init__(self, name): self._name = name @@ -278,8 +272,8 @@ def exit_pass_ctx(self): ############################################################################### -# If an exception occur in ``enter_pass_ctx``, ``PassContext`` disable the pass -# instrumentation. And it will run ``exit_pass_ctx`` of each ``PassInstrument`` +# If an exception occurs in ``enter_pass_ctx``, ``PassContext`` will disable the pass +# instrumentation. And it will run the ``exit_pass_ctx`` of each ``PassInstrument`` # which successfully finished ``enter_pass_ctx``. # # In following example, we can see ``exit_pass_ctx`` of `PassFine_0` is executed after exception. @@ -297,13 +291,13 @@ def exit_pass_ctx(self): print("Catching", str(ex).split("\n")[-1]) ############################################################################### -# Also, all ``PassInstrument`` are cleared. -# So nothing printed while ``override_instruments`` is called. +# Exceptions in ``PassInstrument`` instances cause all instruments of the current ``PassContext`` +# to be cleared, so nothing is printed when ``override_instruments`` is called. demo_ctx.override_instruments([]) # no PassFine_0 exit_pass_ctx printed....etc ############################################################################### -# If an exception occur in ``exit_pass_ctx``, pass instrumentation is disabled. -# Then exception is thrown. That means ``PassInstrument`` registered +# If an exception occurs in ``exit_pass_ctx``, then the pass instrument is disabled. +# Then exception is propagated. That means ``PassInstrument`` instances registered # after the one throwing the exception do not execute ``exit_pass_ctx``. demo_ctx = tvm.transform.PassContext( instruments=[ @@ -321,8 +315,8 @@ def exit_pass_ctx(self): ############################################################################### # Exceptions occured in ``should_run``, ``run_before_pass``, ``run_after_pass`` -# are not handled explicitly -- that means, we rely on the context manager -# (the ``with`` syntax) to exit ``PassContext`` safely. +# are not handled explicitly -- we rely on the context manager (the ``with`` syntax) +# to exit ``PassContext`` safely. # # We use ``run_before_pass`` as an example: @pass_instrument