From 5a444e5f5aa3c77b10000f010058504627d93064 Mon Sep 17 00:00:00 2001
From: yuchaoli <xiamenlyc@163.com>
Date: Thu, 11 Mar 2021 10:55:23 +0800
Subject: [PATCH 01/12] [AutoScheduler] Fix incorrectly array context device
 and hide info at the beginning

---
 python/tvm/auto_scheduler/measure.py           | 4 ++--
 python/tvm/auto_scheduler/relay_integration.py | 5 +++++
 2 files changed, 7 insertions(+), 2 deletions(-)

diff --git a/python/tvm/auto_scheduler/measure.py b/python/tvm/auto_scheduler/measure.py
index 959a9c5da82a..2e81f27b2772 100644
--- a/python/tvm/auto_scheduler/measure.py
+++ b/python/tvm/auto_scheduler/measure.py
@@ -868,7 +868,7 @@ def _timed_eval_func(
                 if arg in tensor_input_map:
                     tensor_name = tensor_input_map[arg]
                     if tensor_name in task_input_names:
-                        args.append(get_task_input_buffer(inp.task.workload_key, tensor_name))
+                        args.append(ndarray.array(get_task_input_buffer(inp.task.workload_key, tensor_name), ctx))
                         task_inputs_count += 1
                     else:
                         raise ValueError(
@@ -1079,7 +1079,7 @@ def _timed_rpc_run(
                 if arg in tensor_input_map:
                     tensor_name = tensor_input_map[arg]
                     if tensor_name in task_input_names:
-                        args.append(get_task_input_buffer(inp.task.workload_key, tensor_name))
+                        args.append(ndarray.array(get_task_input_buffer(inp.task.workload_key, tensor_name), ctx))
                         task_inputs_count += 1
                     else:
                         raise ValueError(
diff --git a/python/tvm/auto_scheduler/relay_integration.py b/python/tvm/auto_scheduler/relay_integration.py
index 68f53125c7ae..c20fc3d4732d 100644
--- a/python/tvm/auto_scheduler/relay_integration.py
+++ b/python/tvm/auto_scheduler/relay_integration.py
@@ -117,12 +117,17 @@ def extract_tasks(
     env = TracingEnvironment(
         TracingMode.EXTRACT_TASK if include_simple_tasks else TracingMode.EXTRACT_COMPLEX_TASK_ONLY
     )
+    
+    dispatch_ctx = DispatchContext.current
+    old_verbose = dispatch_ctx.verbose
+    dispatch_ctx.verbose = 0
     with env:
         # Wrap build call in a new thread to avoid the conflict
         # between python's multiprocessing and tvm's thread pool
         build_thread = threading.Thread(target=call_all_topi_funcs, args=(mod, params, target))
         build_thread.start()
         build_thread.join()
+    dispatch_ctx.verbose = old_verbose
 
     # create search tasks
     tasks = []

From 49e9b60fc2b921895973c4b650aa49be85a9994a Mon Sep 17 00:00:00 2001
From: yuchaoli <xiamenlyc@163.com>
Date: Thu, 11 Mar 2021 11:20:46 +0800
Subject: [PATCH 02/12] Lint fix

---
 python/tvm/auto_scheduler/measure.py          | 12 +++++--
 .../tvm/auto_scheduler/relay_integration.py   |  2 +-
 .../unittest/test_auto_scheduler_measure.py   | 36 +++++++++++++++++--
 3 files changed, 45 insertions(+), 5 deletions(-)

diff --git a/python/tvm/auto_scheduler/measure.py b/python/tvm/auto_scheduler/measure.py
index 2e81f27b2772..d02dcff3bba0 100644
--- a/python/tvm/auto_scheduler/measure.py
+++ b/python/tvm/auto_scheduler/measure.py
@@ -868,7 +868,11 @@ def _timed_eval_func(
                 if arg in tensor_input_map:
                     tensor_name = tensor_input_map[arg]
                     if tensor_name in task_input_names:
-                        args.append(ndarray.array(get_task_input_buffer(inp.task.workload_key, tensor_name), ctx))
+                        args.append(
+                            ndarray.array(
+                                get_task_input_buffer(inp.task.workload_key, tensor_name), ctx
+                            )
+                        )
                         task_inputs_count += 1
                     else:
                         raise ValueError(
@@ -1079,7 +1083,11 @@ def _timed_rpc_run(
                 if arg in tensor_input_map:
                     tensor_name = tensor_input_map[arg]
                     if tensor_name in task_input_names:
-                        args.append(ndarray.array(get_task_input_buffer(inp.task.workload_key, tensor_name), ctx))
+                        args.append(
+                            ndarray.array(
+                                get_task_input_buffer(inp.task.workload_key, tensor_name), ctx
+                            )
+                        )
                         task_inputs_count += 1
                     else:
                         raise ValueError(
diff --git a/python/tvm/auto_scheduler/relay_integration.py b/python/tvm/auto_scheduler/relay_integration.py
index c20fc3d4732d..6cce30f2f559 100644
--- a/python/tvm/auto_scheduler/relay_integration.py
+++ b/python/tvm/auto_scheduler/relay_integration.py
@@ -117,7 +117,7 @@ def extract_tasks(
     env = TracingEnvironment(
         TracingMode.EXTRACT_TASK if include_simple_tasks else TracingMode.EXTRACT_COMPLEX_TASK_ONLY
     )
-    
+
     dispatch_ctx = DispatchContext.current
     old_verbose = dispatch_ctx.verbose
     dispatch_ctx.verbose = 0
diff --git a/tests/python/unittest/test_auto_scheduler_measure.py b/tests/python/unittest/test_auto_scheduler_measure.py
index 116981028cc9..e21459a9b302 100644
--- a/tests/python/unittest/test_auto_scheduler_measure.py
+++ b/tests/python/unittest/test_auto_scheduler_measure.py
@@ -357,7 +357,7 @@ def test_measure_target_host():
 
 
 @tvm.testing.requires_llvm
-def test_measure_special_inputs_map_by_name():
+def test_measure_special_inputs_map_by_name_local_runner():
     @auto_scheduler.register_workload
     def foo():
         X = te.placeholder(shape=[10], dtype="int32")
@@ -384,6 +384,37 @@ def foo():
     assert mress[0].error_no == 0
 
 
+@tvm.testing.requires_llvm
+def test_measure_special_inputs_map_by_name_rpc_runner():
+    @auto_scheduler.register_workload
+    def foo():
+        X = te.placeholder(shape=[10], dtype="int32")
+        Index = te.placeholder(shape=[1], dtype="int32", name="Index")
+        Y = te.compute((1,), lambda i: X[Index[i]])
+        return [X, Index, Y]
+
+    # This workload cannot use random input for the `Index` input
+    task = auto_scheduler.SearchTask(
+        func=foo,
+        target="llvm",
+        task_inputs={
+            "Index": tvm.nd.array(np.array([5], dtype="int32")),
+        },
+    )
+
+    minp = auto_scheduler.MeasureInput(task, task.compute_dag.init_state)
+    local_builder = auto_scheduler.LocalBuilder()
+    measure_ctx = auto_scheduler.LocalRPCMeasureContext(
+        timeout=60, enable_cpu_cache_flush=enable_cpu_cache_flush
+    )
+    rpc_runner = measure_ctx.runner
+
+    bress = local_builder.build([minp])
+    assert bress[0].error_no == 0
+    mress = rpc_runner.run([minp], bress)
+    assert mress[0].error_no == 0
+
+
 if __name__ == "__main__":
     test_record_split_reorder_fuse_annotation()
     test_record_compute_at_root_inline_cache_read_write()
@@ -395,4 +426,5 @@ def foo():
     test_dag_measure_local_builder_runner()
     test_measure_local_builder_rpc_runner()
     test_measure_target_host()
-    test_measure_special_inputs_map_by_name()
+    test_measure_special_inputs_map_by_name_local_runner()
+    test_measure_special_inputs_map_by_name_rpc_runner()

From 0015e8881f0040e26e13f4cbe58b218fa009dfdc Mon Sep 17 00:00:00 2001
From: yuchaoli <xiamenlyc@163.com>
Date: Thu, 11 Mar 2021 11:54:53 +0800
Subject: [PATCH 03/12] Lint fix

---
 .../unittest/test_auto_scheduler_measure.py   | 21 ++++++++++---------
 1 file changed, 11 insertions(+), 10 deletions(-)

diff --git a/tests/python/unittest/test_auto_scheduler_measure.py b/tests/python/unittest/test_auto_scheduler_measure.py
index e21459a9b302..7605b70be6f4 100644
--- a/tests/python/unittest/test_auto_scheduler_measure.py
+++ b/tests/python/unittest/test_auto_scheduler_measure.py
@@ -402,17 +402,18 @@ def foo():
         },
     )
 
-    minp = auto_scheduler.MeasureInput(task, task.compute_dag.init_state)
-    local_builder = auto_scheduler.LocalBuilder()
-    measure_ctx = auto_scheduler.LocalRPCMeasureContext(
-        timeout=60, enable_cpu_cache_flush=enable_cpu_cache_flush
-    )
-    rpc_runner = measure_ctx.runner
+    for enable_cpu_cache_flush in [True, False]:
+        minp = auto_scheduler.MeasureInput(task, task.compute_dag.init_state)
+        local_builder = auto_scheduler.LocalBuilder()
+        measure_ctx = auto_scheduler.LocalRPCMeasureContext(
+            timeout=60, enable_cpu_cache_flush=enable_cpu_cache_flush
+        )
+        rpc_runner = measure_ctx.runner
 
-    bress = local_builder.build([minp])
-    assert bress[0].error_no == 0
-    mress = rpc_runner.run([minp], bress)
-    assert mress[0].error_no == 0
+        bress = local_builder.build([minp])
+        assert bress[0].error_no == 0
+        mress = rpc_runner.run([minp], bress)
+        assert mress[0].error_no == 0
 
 
 if __name__ == "__main__":

From c872407f9f5f88984e68041674871c4ca0fa297f Mon Sep 17 00:00:00 2001
From: yuchaoli <xiamenlyc@163.com>
Date: Wed, 14 Apr 2021 10:42:44 +0800
Subject: [PATCH 04/12] update repo

---
 python/tvm/auto_scheduler/measure.py | 84 +++++++++-------------------
 1 file changed, 26 insertions(+), 58 deletions(-)

diff --git a/python/tvm/auto_scheduler/measure.py b/python/tvm/auto_scheduler/measure.py
index d02dcff3bba0..9f0b97749a28 100644
--- a/python/tvm/auto_scheduler/measure.py
+++ b/python/tvm/auto_scheduler/measure.py
@@ -17,17 +17,13 @@
 
 """
 Distributed measurement infrastructure to measure the runtime costs of tensor programs.
-
 These functions are responsible for building the tvm module, uploading it to
 remote devices, recording the running time costs, and checking the correctness of the output.
-
 We separate the measurement into two steps: build and run.
 A builder builds the executable binary files and a runner runs the binary files to
 get the measurement results. The flow of data structures is
-
   .               `ProgramBuilder`                 `ProgramRunner`
   `MeasureInput` -----------------> `BuildResult` ----------------> `MeasureResult`
-
 We implement these in python to utilize python's multiprocessing and error handling.
 """
 
@@ -44,6 +40,8 @@
 from tvm.ir import transform
 from tvm.autotvm.measure.measure_methods import set_cuda_target_arch
 from tvm.contrib import tar, ndk
+from tvm.target import Target
+
 
 from . import _ffi_api
 from .loop_state import StateObject
@@ -97,7 +95,6 @@ def callback_func(policy, inputs, results):
 
     def callback(self, policy, inputs, results):
         """The callback function.
-
         Parameters
         ----------
         policy: auto_scheduler.search_policy.SearchPolicy
@@ -113,7 +110,6 @@ def callback(self, policy, inputs, results):
 @tvm._ffi.register_object("auto_scheduler.MeasureInput")
 class MeasureInput(Object):
     """Store the input of a measurement.
-
     Parameters
     ----------
     task : SearchTask
@@ -129,7 +125,6 @@ def __init__(self, task, state):
     def serialize(self):
         """Custom serialization to workaround MeasureInput not exposing all its
         members to the TVM ffi interface.
-
         Note that we do not implement __getstate__ as it does not seem to work
         with initialization of the workload registry (maybe because of
         initialization order?).
@@ -149,7 +144,6 @@ def deserialize(data):
 @tvm._ffi.register_object("auto_scheduler.BuildResult")
 class BuildResult(Object):
     """Store the result of a build.
-
     Parameters
     ----------
     filename : Optional[str]
@@ -176,7 +170,6 @@ def __init__(self, filename, args, error_no, error_msg, time_cost):
 @tvm._ffi.register_object("auto_scheduler.MeasureResult")
 class MeasureResult(Object):
     """Store the results of a measurement.
-
     Parameters
     ----------
     costs : List[float]
@@ -204,14 +197,12 @@ def recover_measure_input(inp, rebuild_state=False):
     Recover a deserialized MeasureInput by rebuilding the missing fields.
     1. Rebuid the compute_dag in inp.task
     2. (Optional) Rebuild the stages in inp.state
-
     Parameters
     ----------
     inp: MeasureInput
         The deserialized MeasureInput
     rebuild_state: bool = False
         Whether rebuild the stages in MeasureInput.State
-
     Returns
     -------
     new_input: MeasureInput
@@ -221,10 +212,12 @@ def recover_measure_input(inp, rebuild_state=False):
     from .search_task import SearchTask  # lazily import to avoid recursive dependency
 
     task = inp.task
+    task.target, task.target_host = Target.check_and_update_host_consist(
+        task.target, task.target_host
+    )
     new_task = SearchTask(
         workload_key=task.workload_key,
         target=task.target,
-        target_host=task.target_host,
         hardware_params=task.hardware_params,
         layout_rewrite_option=task.layout_rewrite_option,
         task_inputs=list(task.task_input_names),
@@ -244,14 +237,12 @@ class ProgramBuilder(Object):
 
     def build(self, measure_inputs, verbose=1):
         """Build programs and return results.
-
         Parameters
         ----------
         measure_inputs : List[MeasureInput]
             A List of MeasureInput.
         verbose: int = 1
             Verbosity level. 0 for silent, 1 to output information during program building.
-
         Returns
         -------
         res : List[BuildResult]
@@ -265,7 +256,6 @@ class ProgramRunner(Object):
 
     def run(self, measure_inputs, build_results, verbose=1):
         """Run measurement and return results.
-
         Parameters
         ----------
         measure_inputs : List[MeasureInput]
@@ -274,7 +264,6 @@ def run(self, measure_inputs, build_results, verbose=1):
             A List of BuildResult to be ran.
         verbose: int = 1
             Verbosity level. 0 for silent, 1 to output information during program running.
-
         Returns
         -------
         res : List[MeasureResult]
@@ -287,7 +276,6 @@ class ProgramMeasurer(Object):
     """
     Measurer that measures the time costs of tvm programs
     This class combines ProgramBuilder and ProgramRunner, and provides a simpler API.
-
     Parameters
     ----------
     builder : ProgramBuilder
@@ -312,7 +300,6 @@ def __init__(self, builder, runner, callbacks, verbose, max_continuous_error=Non
 @tvm._ffi.register_object("auto_scheduler.LocalBuilder")
 class LocalBuilder(ProgramBuilder):
     """LocalBuilder use local CPU cores to build programs in parallel.
-
     Parameters
     ----------
     timeout : int = 15
@@ -347,7 +334,6 @@ def __init__(self, timeout=15, n_parallel=multiprocessing.cpu_count(), build_fun
 @tvm._ffi.register_object("auto_scheduler.LocalRunner")
 class LocalRunner(ProgramRunner):
     """LocalRunner that uses local CPU/GPU to measures the time cost of programs.
-
     Parameters
     ----------
     timeout : int = 10
@@ -408,7 +394,6 @@ class RPCRunner(ProgramRunner):
     """RPCRunner that uses RPC call to measures the time cost of programs on remote devices.
     Or sometime we may need to use RPC even in local running to insulate the thread environment.
     (e.g. running CUDA programs)
-
     Parameters
     ----------
     key : str
@@ -493,7 +478,6 @@ def __init__(
 class LocalRPCMeasureContext:
     """A context wrapper for running RPCRunner locally.
     This will launch a local RPC Tracker and local RPC Server.
-
     Parameters
     ----------
     priority : int = 1
@@ -544,9 +528,9 @@ def __init__(
         from tvm.rpc.tracker import Tracker
         from tvm.rpc.server import Server
 
-        ctx = tvm.context("cuda", 0)
-        if ctx.exist:
-            cuda_arch = "sm_" + "".join(ctx.compute_version.split("."))
+        dev = tvm.device("cuda", 0)
+        if dev.exist:
+            cuda_arch = "sm_" + "".join(dev.compute_version.split("."))
             set_cuda_target_arch(cuda_arch)
         host = "0.0.0.0"
         self.tracker = Tracker(host, port=9000, port_end=10000, silent=True)
@@ -602,6 +586,9 @@ def _timed_func(inp_serialized, build_func, verbose):
     tic = time.time()
     inp = MeasureInput.deserialize(inp_serialized)
     task = inp.task
+    task.target, task.target_host = Target.check_and_update_host_consist(
+        task.target, task.target_host
+    )
 
     error_no = MeasureErrorNo.NO_ERROR
     error_msg = None
@@ -622,9 +609,7 @@ def _timed_func(inp_serialized, build_func, verbose):
 
         try:
             with transform.PassContext():
-                func = build_module.build(
-                    sch, args, target=task.target, target_host=task.target_host
-                )
+                func = build_module.build(sch, args, target=task.target)
             func.export_library(filename, build_func)
         # pylint: disable=broad-except
         except Exception:
@@ -645,12 +630,10 @@ def _timed_func(inp_serialized, build_func, verbose):
 def local_build_worker(args):
     """
     Build function of LocalBuilder to be ran in the Builder thread pool.
-
     Parameters
     ----------
     args: Tuple[MeasureInput, str, int, int]
         inputs, build-func, time, verbose args passed to local_builder_build
-
     Returns
     -------
     res : BuildResult
@@ -679,7 +662,6 @@ def local_build_worker(args):
 def local_builder_build(inputs, timeout, n_parallel, build_func="default", verbose=1):
     """
     Build function of LocalBuilder to build the MeasureInputs to runnable modules.
-
     Parameters
     ----------
     inputs : List[MeasureInput]
@@ -693,7 +675,6 @@ def local_builder_build(inputs, timeout, n_parallel, build_func="default", verbo
         The name of build function to process the built module.
     verbose: int = 1
         Verbosity level. 0 for silent, 1 to output information during program building.
-
     Returns
     -------
     res : List[BuildResult]
@@ -729,10 +710,8 @@ def local_builder_build(inputs, timeout, n_parallel, build_func="default", verbo
 
 def register_task_input_check_func(func_name, f=None, override=False):
     """Register a function that checks the input buffer map.
-
     The input function should take a list of Tensor wich indicate the Input/output Tensor of a TVM
     subgraph and return a Map from the input Tensor to its buffer name.
-
     Parameters
     ----------
     func_name : Union[Function, str]
@@ -741,11 +720,9 @@ def register_task_input_check_func(func_name, f=None, override=False):
         The check function to be registered.
     override : boolean = False
         Whether to override existing entry.
-
     Examples
     --------
     .. code-block:: python
-
       @auto_scheduler.register_task_input_check_func
       def check_task_input_by_placeholder_name(args : List[Tensor]):
           tensor_input_map = {}
@@ -775,20 +752,17 @@ def register(myf):
     return register
 
 
-def _prepare_input_map(args):
+def prepare_input_map(args):
     """This function deals with special task inputs. Map the input Tensor of a TVM subgraph
     to a specific buffer name in the global buffer map.
-
     Parameters
     ----------
     args : List[Tensor]
         Input/output Tensor of a TVM subgraph.
-
     Returns
     -------
     Dict[Tensor, str] :
         Map from the input Tensor to its buffer name.
-
     Notes
     -----
     The buffer name is specially designed, and these buffer should be provided in
@@ -835,7 +809,7 @@ def _timed_eval_func(
     error_msg = None
     try:
         func = module.load_module(build_res.filename)
-        ctx = ndarray.context(str(inp.task.target), 0)
+        dev = ndarray.device(str(inp.task.target), 0)
         # Limitation:
         # We can not get PackFunction directly in the remote mode as it is wrapped
         # under the std::function. We could lift the restriction later once we fold
@@ -844,7 +818,7 @@ def _timed_eval_func(
         f_prepare = "cache_flush_cpu_non_first_arg" if enable_cpu_cache_flush else ""
         time_f = func.time_evaluator(
             func.entry_name,
-            ctx,
+            dev,
             number=number,
             repeat=repeat,
             min_repeat_ms=min_repeat_ms,
@@ -861,7 +835,7 @@ def _timed_eval_func(
             random_fill = tvm.get_global_func("tvm.contrib.random.random_fill", True)
             assert random_fill, "Please make sure USE_RANDOM is ON in the config.cmake"
 
-            tensor_input_map = _prepare_input_map(build_res.args) if task_input_names else {}
+            tensor_input_map = prepare_input_map(build_res.args) if task_input_names else {}
             args = []
             task_inputs_count = 0
             for arg in build_res.args:
@@ -870,7 +844,7 @@ def _timed_eval_func(
                     if tensor_name in task_input_names:
                         args.append(
                             ndarray.array(
-                                get_task_input_buffer(inp.task.workload_key, tensor_name), ctx
+                                get_task_input_buffer(inp.task.workload_key, tensor_name), dev
                             )
                         )
                         task_inputs_count += 1
@@ -880,14 +854,14 @@ def _timed_eval_func(
                             + "should provide with `SearchTask(..., task_inputs={...})`"
                         )
                 else:
-                    empty_array = ndarray.empty(get_const_tuple(arg.shape), arg.dtype, ctx)
+                    empty_array = ndarray.empty(get_const_tuple(arg.shape), arg.dtype, dev)
                     random_fill(empty_array)
                     args.append(empty_array)
             if task_inputs_count != len(task_input_names):
                 logger.warning(
                     "task_inputs not fully matched, check if there's any unexpected error"
                 )
-            ctx.sync()
+            dev.sync()
             costs = time_f(*args).results
         # pylint: disable=broad-except
         except Exception:
@@ -921,7 +895,6 @@ def local_run(
 ):
     """
     Run function of LocalRunner to test the performance of the input BuildResults.
-
     Parameters
     ----------
     inputs : List[MeasureInput]
@@ -957,7 +930,6 @@ def local_run(
         This is only has effect on CPU task.
     verbose: int = 1
         Verbosity level. 0 for silent, 1 to output information during program measuring.
-
     Returns
     -------
     res : List[MeasureResult]
@@ -1048,7 +1020,7 @@ def _timed_rpc_run(
         remote = request_remote(key, host, port, priority, timeout)
         remote.upload(build_res.filename)
         func = remote.load_module(os.path.split(build_res.filename)[1])
-        ctx = remote.context(str(inp.task.target), 0)
+        dev = remote.device(str(inp.task.target), 0)
         # Limitation:
         # We can not get PackFunction directly in the remote mode as it is wrapped
         # under the std::function. We could lift the restriction later once we fold
@@ -1057,7 +1029,7 @@ def _timed_rpc_run(
         f_prepare = "cache_flush_cpu_non_first_arg" if enable_cpu_cache_flush else ""
         time_f = func.time_evaluator(
             func.entry_name,
-            ctx,
+            dev,
             number=number,
             repeat=repeat,
             min_repeat_ms=min_repeat_ms,
@@ -1076,7 +1048,7 @@ def _timed_rpc_run(
                 random_fill
             ), "Please make sure USE_RANDOM is ON in the config.cmake on the remote devices"
 
-            tensor_input_map = _prepare_input_map(build_res.args) if task_input_names else {}
+            tensor_input_map = prepare_input_map(build_res.args) if task_input_names else {}
             args = []
             task_inputs_count = 0
             for arg in build_res.args:
@@ -1085,7 +1057,7 @@ def _timed_rpc_run(
                     if tensor_name in task_input_names:
                         args.append(
                             ndarray.array(
-                                get_task_input_buffer(inp.task.workload_key, tensor_name), ctx
+                                get_task_input_buffer(inp.task.workload_key, tensor_name), dev
                             )
                         )
                         task_inputs_count += 1
@@ -1095,14 +1067,14 @@ def _timed_rpc_run(
                             + "should provide with `SearchTask(..., task_inputs={...})`"
                         )
                 else:
-                    empty_array = ndarray.empty(get_const_tuple(arg.shape), arg.dtype, ctx)
+                    empty_array = ndarray.empty(get_const_tuple(arg.shape), arg.dtype, dev)
                     random_fill(empty_array)
                     args.append(empty_array)
             if task_inputs_count != len(task_input_names):
                 logger.warning(
                     "task_inputs not fully matched, check if there's any unexpected error"
                 )
-            ctx.sync()
+            dev.sync()
             costs = time_f(*args).results
 
             # clean up remote files
@@ -1130,12 +1102,10 @@ def _timed_rpc_run(
 
 def _rpc_run_worker(args):
     """Function to be ran in the RPCRunner thread pool.
-
     Parameters
     ----------
     args : Tuple[MeasureInput, BuildResult, ...]
         Single input and build result plus the rest of the arguments to `rpc_runner_run`.
-
     Returns
     -------
     res : MeasureResult
@@ -1194,7 +1164,6 @@ def rpc_runner_run(
     verbose=1,
 ):
     """Run function of RPCRunner to test the performance of the input BuildResults.
-
     Parameters
     ----------
     inputs : List[MeasureInput]
@@ -1240,7 +1209,6 @@ def rpc_runner_run(
         This is only has effect on CPU task.
     verbose: int = 1
         Verbosity level. 0 for silent, 1 to output information during program measuring.
-
     Returns
     -------
     res : List[MeasureResult]
@@ -1281,4 +1249,4 @@ def rpc_runner_run(
     if verbose >= 1:
         print("")
 
-    return results
+    return results
\ No newline at end of file

From 97eae9f895f6fcfdd910523327219ed7e89b9c2f Mon Sep 17 00:00:00 2001
From: yuchaoli <xiamenlyc@163.com>
Date: Wed, 14 Apr 2021 10:49:29 +0800
Subject: [PATCH 05/12] Fix Pytorch matmul conversion when given (2-dim, N-dim)
 input pair

---
 python/tvm/relay/frontend/pytorch.py | 12 +++++++++++-
 1 file changed, 11 insertions(+), 1 deletion(-)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index cb9ea6a043f4..2bcbd143a37a 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -1580,7 +1580,7 @@ def matmul(self, inputs, input_types):
         b_shape = self.infer_shape_with_prelude(inputs_1)
 
         # When performing a batch matmul, we need to properly handle N-dim shapes.
-        if len(a_shape) > 2 or len(b_shape) > 2:
+        if len(a_shape) > 2 and len(b_shape) > 2:
             # Convert a into a 3 dimensional tensors.
             need_reshape_output = False
             if len(a_shape) != 3:
@@ -1607,6 +1607,12 @@ def matmul(self, inputs, input_types):
                 return _op.reshape(output, [*a_shape[:-2], a_shape[-2], b_shape[-1]])
             return output
 
+        # Reshape a or b into a 2 dimensional tensor
+        if len(a_shape) > 2:
+            inputs_0 = _op.reshape(inputs_0, [-1, a_shape[-1]])
+        if len(b_shape) > 2:
+            inputs_1 = _op.reshape(inputs_1, [-1, b_shape[-1]])
+
         # Otherwise a simple dense op will get the job done.
         if len(b_shape) == 1:
             input_1 = _op.expand_dims(inputs_1, 0, 1)
@@ -1618,6 +1624,10 @@ def matmul(self, inputs, input_types):
         if len(b_shape) == 1:
             out = _op.squeeze(out, axis=[-1])
 
+        # Reshape a into a N dimensional tensor when its dim > 2
+        if len(a_shape) > 2:
+            out = _op.reshape(out, [*a_shape[:-1], b_shape[-1]])
+
         return out
 
     def expand(self, inputs, input_types):

From 67daaed13866051834a42e163ed0ecfe2cbf0009 Mon Sep 17 00:00:00 2001
From: yuchaoli <xiamenlyc@163.com>
Date: Wed, 14 Apr 2021 10:56:39 +0800
Subject: [PATCH 06/12] update measure.py

---
 python/tvm/auto_scheduler/measure.py | 39 +++++++++++++++++++++++++++-
 1 file changed, 38 insertions(+), 1 deletion(-)

diff --git a/python/tvm/auto_scheduler/measure.py b/python/tvm/auto_scheduler/measure.py
index 9f0b97749a28..83f1bcec7ebc 100644
--- a/python/tvm/auto_scheduler/measure.py
+++ b/python/tvm/auto_scheduler/measure.py
@@ -17,13 +17,17 @@
 
 """
 Distributed measurement infrastructure to measure the runtime costs of tensor programs.
+
 These functions are responsible for building the tvm module, uploading it to
 remote devices, recording the running time costs, and checking the correctness of the output.
+
 We separate the measurement into two steps: build and run.
 A builder builds the executable binary files and a runner runs the binary files to
 get the measurement results. The flow of data structures is
+
   .               `ProgramBuilder`                 `ProgramRunner`
   `MeasureInput` -----------------> `BuildResult` ----------------> `MeasureResult`
+
 We implement these in python to utilize python's multiprocessing and error handling.
 """
 
@@ -95,6 +99,7 @@ def callback_func(policy, inputs, results):
 
     def callback(self, policy, inputs, results):
         """The callback function.
+
         Parameters
         ----------
         policy: auto_scheduler.search_policy.SearchPolicy
@@ -110,6 +115,7 @@ def callback(self, policy, inputs, results):
 @tvm._ffi.register_object("auto_scheduler.MeasureInput")
 class MeasureInput(Object):
     """Store the input of a measurement.
+
     Parameters
     ----------
     task : SearchTask
@@ -125,6 +131,7 @@ def __init__(self, task, state):
     def serialize(self):
         """Custom serialization to workaround MeasureInput not exposing all its
         members to the TVM ffi interface.
+
         Note that we do not implement __getstate__ as it does not seem to work
         with initialization of the workload registry (maybe because of
         initialization order?).
@@ -144,6 +151,7 @@ def deserialize(data):
 @tvm._ffi.register_object("auto_scheduler.BuildResult")
 class BuildResult(Object):
     """Store the result of a build.
+
     Parameters
     ----------
     filename : Optional[str]
@@ -170,6 +178,7 @@ def __init__(self, filename, args, error_no, error_msg, time_cost):
 @tvm._ffi.register_object("auto_scheduler.MeasureResult")
 class MeasureResult(Object):
     """Store the results of a measurement.
+
     Parameters
     ----------
     costs : List[float]
@@ -197,12 +206,14 @@ def recover_measure_input(inp, rebuild_state=False):
     Recover a deserialized MeasureInput by rebuilding the missing fields.
     1. Rebuid the compute_dag in inp.task
     2. (Optional) Rebuild the stages in inp.state
+
     Parameters
     ----------
     inp: MeasureInput
         The deserialized MeasureInput
     rebuild_state: bool = False
         Whether rebuild the stages in MeasureInput.State
+
     Returns
     -------
     new_input: MeasureInput
@@ -237,12 +248,14 @@ class ProgramBuilder(Object):
 
     def build(self, measure_inputs, verbose=1):
         """Build programs and return results.
+
         Parameters
         ----------
         measure_inputs : List[MeasureInput]
             A List of MeasureInput.
         verbose: int = 1
             Verbosity level. 0 for silent, 1 to output information during program building.
+
         Returns
         -------
         res : List[BuildResult]
@@ -256,6 +269,7 @@ class ProgramRunner(Object):
 
     def run(self, measure_inputs, build_results, verbose=1):
         """Run measurement and return results.
+
         Parameters
         ----------
         measure_inputs : List[MeasureInput]
@@ -264,6 +278,7 @@ def run(self, measure_inputs, build_results, verbose=1):
             A List of BuildResult to be ran.
         verbose: int = 1
             Verbosity level. 0 for silent, 1 to output information during program running.
+
         Returns
         -------
         res : List[MeasureResult]
@@ -276,6 +291,7 @@ class ProgramMeasurer(Object):
     """
     Measurer that measures the time costs of tvm programs
     This class combines ProgramBuilder and ProgramRunner, and provides a simpler API.
+
     Parameters
     ----------
     builder : ProgramBuilder
@@ -300,6 +316,7 @@ def __init__(self, builder, runner, callbacks, verbose, max_continuous_error=Non
 @tvm._ffi.register_object("auto_scheduler.LocalBuilder")
 class LocalBuilder(ProgramBuilder):
     """LocalBuilder use local CPU cores to build programs in parallel.
+
     Parameters
     ----------
     timeout : int = 15
@@ -334,6 +351,7 @@ def __init__(self, timeout=15, n_parallel=multiprocessing.cpu_count(), build_fun
 @tvm._ffi.register_object("auto_scheduler.LocalRunner")
 class LocalRunner(ProgramRunner):
     """LocalRunner that uses local CPU/GPU to measures the time cost of programs.
+
     Parameters
     ----------
     timeout : int = 10
@@ -394,6 +412,7 @@ class RPCRunner(ProgramRunner):
     """RPCRunner that uses RPC call to measures the time cost of programs on remote devices.
     Or sometime we may need to use RPC even in local running to insulate the thread environment.
     (e.g. running CUDA programs)
+
     Parameters
     ----------
     key : str
@@ -478,6 +497,7 @@ def __init__(
 class LocalRPCMeasureContext:
     """A context wrapper for running RPCRunner locally.
     This will launch a local RPC Tracker and local RPC Server.
+
     Parameters
     ----------
     priority : int = 1
@@ -630,10 +650,12 @@ def _timed_func(inp_serialized, build_func, verbose):
 def local_build_worker(args):
     """
     Build function of LocalBuilder to be ran in the Builder thread pool.
+
     Parameters
     ----------
     args: Tuple[MeasureInput, str, int, int]
         inputs, build-func, time, verbose args passed to local_builder_build
+
     Returns
     -------
     res : BuildResult
@@ -662,6 +684,7 @@ def local_build_worker(args):
 def local_builder_build(inputs, timeout, n_parallel, build_func="default", verbose=1):
     """
     Build function of LocalBuilder to build the MeasureInputs to runnable modules.
+
     Parameters
     ----------
     inputs : List[MeasureInput]
@@ -675,6 +698,7 @@ def local_builder_build(inputs, timeout, n_parallel, build_func="default", verbo
         The name of build function to process the built module.
     verbose: int = 1
         Verbosity level. 0 for silent, 1 to output information during program building.
+
     Returns
     -------
     res : List[BuildResult]
@@ -710,8 +734,10 @@ def local_builder_build(inputs, timeout, n_parallel, build_func="default", verbo
 
 def register_task_input_check_func(func_name, f=None, override=False):
     """Register a function that checks the input buffer map.
+
     The input function should take a list of Tensor wich indicate the Input/output Tensor of a TVM
     subgraph and return a Map from the input Tensor to its buffer name.
+
     Parameters
     ----------
     func_name : Union[Function, str]
@@ -720,9 +746,11 @@ def register_task_input_check_func(func_name, f=None, override=False):
         The check function to be registered.
     override : boolean = False
         Whether to override existing entry.
+
     Examples
     --------
     .. code-block:: python
+
       @auto_scheduler.register_task_input_check_func
       def check_task_input_by_placeholder_name(args : List[Tensor]):
           tensor_input_map = {}
@@ -755,14 +783,17 @@ def register(myf):
 def prepare_input_map(args):
     """This function deals with special task inputs. Map the input Tensor of a TVM subgraph
     to a specific buffer name in the global buffer map.
+
     Parameters
     ----------
     args : List[Tensor]
         Input/output Tensor of a TVM subgraph.
+
     Returns
     -------
     Dict[Tensor, str] :
         Map from the input Tensor to its buffer name.
+
     Notes
     -----
     The buffer name is specially designed, and these buffer should be provided in
@@ -895,6 +926,7 @@ def local_run(
 ):
     """
     Run function of LocalRunner to test the performance of the input BuildResults.
+
     Parameters
     ----------
     inputs : List[MeasureInput]
@@ -930,6 +962,7 @@ def local_run(
         This is only has effect on CPU task.
     verbose: int = 1
         Verbosity level. 0 for silent, 1 to output information during program measuring.
+
     Returns
     -------
     res : List[MeasureResult]
@@ -1102,10 +1135,12 @@ def _timed_rpc_run(
 
 def _rpc_run_worker(args):
     """Function to be ran in the RPCRunner thread pool.
+
     Parameters
     ----------
     args : Tuple[MeasureInput, BuildResult, ...]
         Single input and build result plus the rest of the arguments to `rpc_runner_run`.
+
     Returns
     -------
     res : MeasureResult
@@ -1164,6 +1199,7 @@ def rpc_runner_run(
     verbose=1,
 ):
     """Run function of RPCRunner to test the performance of the input BuildResults.
+
     Parameters
     ----------
     inputs : List[MeasureInput]
@@ -1209,6 +1245,7 @@ def rpc_runner_run(
         This is only has effect on CPU task.
     verbose: int = 1
         Verbosity level. 0 for silent, 1 to output information during program measuring.
+
     Returns
     -------
     res : List[MeasureResult]
@@ -1249,4 +1286,4 @@ def rpc_runner_run(
     if verbose >= 1:
         print("")
 
-    return results
\ No newline at end of file
+    return results

From 2dc1adc9226913700cbf4201e0eeb6a82a893023 Mon Sep 17 00:00:00 2001
From: yuchaoli <xiamenlyc@163.com>
Date: Wed, 14 Apr 2021 10:58:43 +0800
Subject: [PATCH 07/12] Lint fix

---
 python/tvm/relay/frontend/pytorch.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 2bcbd143a37a..4d3a66086a5d 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -1624,7 +1624,7 @@ def matmul(self, inputs, input_types):
         if len(b_shape) == 1:
             out = _op.squeeze(out, axis=[-1])
 
-        # Reshape a into a N dimensional tensor when its dim > 2
+        # Reshape output into a N dimensional tensor when a dim > 2
         if len(a_shape) > 2:
             out = _op.reshape(out, [*a_shape[:-1], b_shape[-1]])
 

From 892ca94b01350030da28e3ed5eb1602cfc02d331 Mon Sep 17 00:00:00 2001
From: yuchaoli <xiamenlyc@163.com>
Date: Wed, 14 Apr 2021 14:56:40 +0800
Subject: [PATCH 08/12] fix bug && add ut for pytorch matmul

---
 python/tvm/relay/frontend/pytorch.py          | 20 +++++++------
 tests/python/frontend/pytorch/test_forward.py | 28 ++++++++++++++++---
 2 files changed, 35 insertions(+), 13 deletions(-)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 4d3a66086a5d..b89ea6ee5d51 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -1606,17 +1606,16 @@ def matmul(self, inputs, input_types):
             if need_reshape_output:
                 return _op.reshape(output, [*a_shape[:-2], a_shape[-2], b_shape[-1]])
             return output
-
-        # Reshape a or b into a 2 dimensional tensor
-        if len(a_shape) > 2:
+        elif len(a_shape) > 2:
             inputs_0 = _op.reshape(inputs_0, [-1, a_shape[-1]])
-        if len(b_shape) > 2:
-            inputs_1 = _op.reshape(inputs_1, [-1, b_shape[-1]])
 
-        # Otherwise a simple dense op will get the job done.
-        if len(b_shape) == 1:
+        if len(b_shape) > 2:
+            trans_axes = list(range(len(b_shape)))
+            trans_axes[-2], trans_axes[-1] = trans_axes[-1], trans_axes[-2]
+            input_1 = _op.reshape(_op.transpose(inputs_1, trans_axes), [-1, b_shape[-2]])
+        elif len(b_shape) == 1:
             input_1 = _op.expand_dims(inputs_1, 0, 1)
-        else:
+        elif len(b_shape) == 2:
             input_1 = _op.transpose(inputs_1, axes=(1, 0))
 
         out = _op.nn.dense(inputs_0, input_1)
@@ -1624,9 +1623,12 @@ def matmul(self, inputs, input_types):
         if len(b_shape) == 1:
             out = _op.squeeze(out, axis=[-1])
 
-        # Reshape output into a N dimensional tensor when a dim > 2
+        # Reshape output into a N dimensional tensor when a or b dim > 2
         if len(a_shape) > 2:
             out = _op.reshape(out, [*a_shape[:-1], b_shape[-1]])
+        elif len(b_shape) > 2:
+            out = _op.reshape(out, [a_shape[-2], -1, b_shape[-1]])
+            out = _op.reshape(_op.transpose(out, [1, 0, 2]), [*b_shape[:-2], a_shape[-2], b_shape[-1]])
 
         return out
 
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 9ec52987c354..c539ff684aab 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -162,7 +162,7 @@ def measure_latency(model, input_shapes, output_shapes, thresh, dryruns=40):
                 return est
 
 
-def verify_model(model_name, input_data=[], custom_convert_map={}, rtol=1e-5, atol=1e-5):
+def verify_model(model_name, input_data=[], custom_convert_map={}, rtol=1e-5, atol=1e-5, expected_ops=[]):
     """Assert that the output of a compiled model matches with that of its
     baseline."""
     if isinstance(model_name, str):
@@ -219,6 +219,21 @@ def verify_model(model_name, input_data=[], custom_convert_map={}, rtol=1e-5, at
 
                 assert_shapes_match(baseline_output, compiled_output)
                 tvm.testing.assert_allclose(baseline_output, compiled_output, rtol=rtol, atol=atol)
+
+    if len(expected_ops) != 0:
+        found_op = dict.fromkeys(expected_ops, False)
+        def visit(op):
+            if isinstance(op, tvm.ir.op.Op):
+                if op.name in expected_ops:
+                    found_op[op.name] = True
+                
+        tvm.relay.analysis.post_order_visit(mod['main'].body, visit)
+
+        for op_name, is_found in enumerate(found_op):
+            if not is_found:
+                msg = "TVM Relay do not contain expected op [{}]"
+                raise AssertionError(msg.format(op_name))
+
     del model_name
     del baseline_model
     torch.cuda.empty_cache()
@@ -3304,17 +3319,22 @@ def forward(self, *args):
     # matrix x matrix
     tensor1 = torch.randn(10, 4)
     tensor2 = torch.randn(4, 10)
-    verify_model(MatMul1().float().eval(), input_data=[tensor1, tensor2])
+    verify_model(MatMul1().float().eval(), input_data=[tensor1, tensor2], expected_ops=['nn.dense'])
 
     # batched matrix x batched matrix
     tensor1 = torch.randn(10, 3, 4)
     tensor2 = torch.randn(10, 4, 5)
-    verify_model(MatMul1().float().eval(), input_data=[tensor1, tensor2])
+    verify_model(MatMul1().float().eval(), input_data=[tensor1, tensor2], expected_ops=['nn.batch_matmul'])
 
     # batched matrix x broadcasted matrix
     tensor1 = torch.randn(10, 3, 4)
     tensor2 = torch.randn(4, 5)
-    verify_model(MatMul1().float().eval(), input_data=[tensor1, tensor2])
+    verify_model(MatMul1().float().eval(), input_data=[tensor1, tensor2], expected_ops=['nn.dense'])
+
+    # broadcasted matrix x batched matrix
+    tensor1 = torch.randn(10, 4)
+    tensor2 = torch.randn(3, 4, 5)
+    verify_model(MatMul1().float().eval(), input_data=[tensor1, tensor2], expected_ops=['nn.dense'])
 
     # batched matrix x batched matrix
     tensor1 = torch.randn(1, 12, 14, 64)

From df9caa39f96c8e6d075f4367ba11bfd27cd948c4 Mon Sep 17 00:00:00 2001
From: yuchaoli <xiamenlyc@163.com>
Date: Wed, 14 Apr 2021 15:17:49 +0800
Subject: [PATCH 09/12] update ut

---
 tests/python/frontend/pytorch/test_forward.py | 12 +++++-------
 1 file changed, 5 insertions(+), 7 deletions(-)

diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index c539ff684aab..3cb40738d5cd 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -220,19 +220,17 @@ def verify_model(model_name, input_data=[], custom_convert_map={}, rtol=1e-5, at
                 assert_shapes_match(baseline_output, compiled_output)
                 tvm.testing.assert_allclose(baseline_output, compiled_output, rtol=rtol, atol=atol)
 
-    if len(expected_ops) != 0:
-        found_op = dict.fromkeys(expected_ops, False)
+    if expected_ops:
         def visit(op):
             if isinstance(op, tvm.ir.op.Op):
                 if op.name in expected_ops:
-                    found_op[op.name] = True
+                    expected_ops.remove(op.name)
                 
         tvm.relay.analysis.post_order_visit(mod['main'].body, visit)
 
-        for op_name, is_found in enumerate(found_op):
-            if not is_found:
-                msg = "TVM Relay do not contain expected op [{}]"
-                raise AssertionError(msg.format(op_name))
+        if expected_ops:
+            msg = "TVM Relay do not contain expected ops {}"
+            raise AssertionError(msg.format(expected_ops))
 
     del model_name
     del baseline_model

From f2456bc04396fcca5bb089daffd840d292ffc11b Mon Sep 17 00:00:00 2001
From: yuchaoli <xiamenlyc@163.com>
Date: Wed, 14 Apr 2021 15:18:51 +0800
Subject: [PATCH 10/12] Lint fix

---
 tests/python/frontend/pytorch/test_forward.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 3cb40738d5cd..1fc2483afdad 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -225,7 +225,7 @@ def visit(op):
             if isinstance(op, tvm.ir.op.Op):
                 if op.name in expected_ops:
                     expected_ops.remove(op.name)
-                
+
         tvm.relay.analysis.post_order_visit(mod['main'].body, visit)
 
         if expected_ops:

From 0bae9dcd761b694872cc69bc0748f989f00833fa Mon Sep 17 00:00:00 2001
From: yuchaoli <xiamenlyc@163.com>
Date: Wed, 14 Apr 2021 15:20:48 +0800
Subject: [PATCH 11/12] update commit

---
 python/tvm/relay/frontend/pytorch.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index b89ea6ee5d51..7a42adf039a5 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -1613,10 +1613,10 @@ def matmul(self, inputs, input_types):
             trans_axes = list(range(len(b_shape)))
             trans_axes[-2], trans_axes[-1] = trans_axes[-1], trans_axes[-2]
             input_1 = _op.reshape(_op.transpose(inputs_1, trans_axes), [-1, b_shape[-2]])
-        elif len(b_shape) == 1:
-            input_1 = _op.expand_dims(inputs_1, 0, 1)
         elif len(b_shape) == 2:
             input_1 = _op.transpose(inputs_1, axes=(1, 0))
+        elif len(b_shape) == 1:
+            input_1 = _op.expand_dims(inputs_1, 0, 1)
 
         out = _op.nn.dense(inputs_0, input_1)
 

From 51e19b51bafe7c1f5cfaab53b6fcc32c36974379 Mon Sep 17 00:00:00 2001
From: yuchaoli <xiamenlyc@163.com>
Date: Wed, 14 Apr 2021 17:44:47 +0800
Subject: [PATCH 12/12] Lint fix

---
 python/tvm/relay/frontend/pytorch.py          |  4 +++-
 tests/python/frontend/pytorch/test_forward.py | 17 +++++++++++------
 2 files changed, 14 insertions(+), 7 deletions(-)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 7a42adf039a5..a31c44a369f9 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -1628,7 +1628,9 @@ def matmul(self, inputs, input_types):
             out = _op.reshape(out, [*a_shape[:-1], b_shape[-1]])
         elif len(b_shape) > 2:
             out = _op.reshape(out, [a_shape[-2], -1, b_shape[-1]])
-            out = _op.reshape(_op.transpose(out, [1, 0, 2]), [*b_shape[:-2], a_shape[-2], b_shape[-1]])
+            out = _op.reshape(
+                _op.transpose(out, [1, 0, 2]), [*b_shape[:-2], a_shape[-2], b_shape[-1]]
+            )
 
         return out
 
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 1fc2483afdad..bff5bb60e24f 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -162,7 +162,9 @@ def measure_latency(model, input_shapes, output_shapes, thresh, dryruns=40):
                 return est
 
 
-def verify_model(model_name, input_data=[], custom_convert_map={}, rtol=1e-5, atol=1e-5, expected_ops=[]):
+def verify_model(
+    model_name, input_data=[], custom_convert_map={}, rtol=1e-5, atol=1e-5, expected_ops=[]
+):
     """Assert that the output of a compiled model matches with that of its
     baseline."""
     if isinstance(model_name, str):
@@ -221,12 +223,13 @@ def verify_model(model_name, input_data=[], custom_convert_map={}, rtol=1e-5, at
                 tvm.testing.assert_allclose(baseline_output, compiled_output, rtol=rtol, atol=atol)
 
     if expected_ops:
+
         def visit(op):
             if isinstance(op, tvm.ir.op.Op):
                 if op.name in expected_ops:
                     expected_ops.remove(op.name)
 
-        tvm.relay.analysis.post_order_visit(mod['main'].body, visit)
+        tvm.relay.analysis.post_order_visit(mod["main"].body, visit)
 
         if expected_ops:
             msg = "TVM Relay do not contain expected ops {}"
@@ -3317,22 +3320,24 @@ def forward(self, *args):
     # matrix x matrix
     tensor1 = torch.randn(10, 4)
     tensor2 = torch.randn(4, 10)
-    verify_model(MatMul1().float().eval(), input_data=[tensor1, tensor2], expected_ops=['nn.dense'])
+    verify_model(MatMul1().float().eval(), input_data=[tensor1, tensor2], expected_ops=["nn.dense"])
 
     # batched matrix x batched matrix
     tensor1 = torch.randn(10, 3, 4)
     tensor2 = torch.randn(10, 4, 5)
-    verify_model(MatMul1().float().eval(), input_data=[tensor1, tensor2], expected_ops=['nn.batch_matmul'])
+    verify_model(
+        MatMul1().float().eval(), input_data=[tensor1, tensor2], expected_ops=["nn.batch_matmul"]
+    )
 
     # batched matrix x broadcasted matrix
     tensor1 = torch.randn(10, 3, 4)
     tensor2 = torch.randn(4, 5)
-    verify_model(MatMul1().float().eval(), input_data=[tensor1, tensor2], expected_ops=['nn.dense'])
+    verify_model(MatMul1().float().eval(), input_data=[tensor1, tensor2], expected_ops=["nn.dense"])
 
     # broadcasted matrix x batched matrix
     tensor1 = torch.randn(10, 4)
     tensor2 = torch.randn(3, 4, 5)
-    verify_model(MatMul1().float().eval(), input_data=[tensor1, tensor2], expected_ops=['nn.dense'])
+    verify_model(MatMul1().float().eval(), input_data=[tensor1, tensor2], expected_ops=["nn.dense"])
 
     # batched matrix x batched matrix
     tensor1 = torch.randn(1, 12, 14, 64)