diff --git a/python/tvm/auto_scheduler/measure.py b/python/tvm/auto_scheduler/measure.py
index d02dcff3bba0..5d1a79bf18a6 100644
--- a/python/tvm/auto_scheduler/measure.py
+++ b/python/tvm/auto_scheduler/measure.py
@@ -775,7 +775,7 @@ def register(myf):
     return register
 
 
-def _prepare_input_map(args):
+def prepare_input_map(args):
     """This function deals with special task inputs. Map the input Tensor of a TVM subgraph
     to a specific buffer name in the global buffer map.
 
@@ -861,7 +861,7 @@ def _timed_eval_func(
         random_fill = tvm.get_global_func("tvm.contrib.random.random_fill", True)
         assert random_fill, "Please make sure USE_RANDOM is ON in the config.cmake"
 
-        tensor_input_map = _prepare_input_map(build_res.args) if task_input_names else {}
+        tensor_input_map = prepare_input_map(build_res.args) if task_input_names else {}
         args = []
         task_inputs_count = 0
         for arg in build_res.args:
@@ -1076,7 +1076,7 @@ def _timed_rpc_run(
                 random_fill
             ), "Please make sure USE_RANDOM is ON in the config.cmake on the remote devices"
 
-            tensor_input_map = _prepare_input_map(build_res.args) if task_input_names else {}
+            tensor_input_map = prepare_input_map(build_res.args) if task_input_names else {}
            args = []
            task_inputs_count = 0
            for arg in build_res.args:
diff --git a/python/tvm/auto_scheduler/relay_integration.py b/python/tvm/auto_scheduler/relay_integration.py
index 6cce30f2f559..e931fc6e298d 100644
--- a/python/tvm/auto_scheduler/relay_integration.py
+++ b/python/tvm/auto_scheduler/relay_integration.py
@@ -142,6 +142,12 @@ def extract_tasks(
                 # When auto scheduler is used in end to end network, try to apply layout rewrite
                 # to improve the overall performance
                 layout_rewrite_option=LayoutRewriteOption.get_target_default(target, True),
+                task_inputs=(
+                    env.wkl_key_to_input_names[wkl_key]
+                    if wkl_key in env.wkl_key_to_input_names
+                    else None
+                ),
+                task_inputs_save_to_file=True,
             )
         )
         weights.append(weight)
@@ -166,6 +172,7 @@ def __init__(self, tracing_mode):
         self.tracing_mode = tracing_mode
         self.relay_disable_build_cache = "false"
         self.wkl_key_to_weight = {}
+        self.wkl_key_to_input_names = {}
 
     def __enter__(self):
         TracingEnvironment.current = self
@@ -175,17 +182,30 @@ def __exit__(self, exc_type, exc_val, exc_tb):
         TracingEnvironment.current = None
 
     def add_workload_key(self, workload_key):
-        """Add the workload key of a search task
+        """Add the workload key of a search task.
 
         Parameters
         ----------
         workload_key: str
-            The workload key of a task
+            The workload key of a task.
         """
         if workload_key not in self.wkl_key_to_weight:
             self.wkl_key_to_weight[workload_key] = 0
         self.wkl_key_to_weight[workload_key] += 1
 
+    def add_workload_input_names(self, workload_key, input_names):
+        """Add special task inputs to this workload.
+
+        Parameters
+        ----------
+        workload_key : str
+            The workload key of a task.
+
+        input_names : List[str]
+            A list of input names.
+        """
+        self.wkl_key_to_input_names[workload_key] = input_names
+
 
 @tvm._ffi.register_func("auto_scheduler.enter_layout_rewrite")
 def enter_layout_rewrite():
@@ -274,6 +294,9 @@ def auto_schedule_topi(outs):
         None in the tracing mode so that the fallback topi schedule will be used.
     """
     # pylint: disable=import-outside-toplevel
+    from tvm.auto_scheduler.measure import (
+        prepare_input_map,
+    )  # lazily import to avoid recursive dependency
     io_tensors, has_layout_free, has_complex_op = traverse_to_get_io_tensors(outs)
     if not io_tensors:
         # The compute includes dynamic shapes which are not supported yet.
@@ -305,6 +328,9 @@ def auto_schedule_topi(outs):
         # in the task extraction mode
         if has_complex_op or env.tracing_mode == TracingMode.EXTRACT_TASK:
             env.add_workload_key(key)
+            input_map = prepare_input_map(io_tensors)
+            if input_map:
+                env.add_workload_input_names(key, list(input_map.values()))
     elif env.tracing_mode == TracingMode.PREPARE_LAYOUT_REWRITE:
         # in prepare_layout_rewrite mode
         if (
diff --git a/python/tvm/auto_scheduler/search_task.py b/python/tvm/auto_scheduler/search_task.py
index 57e239cf79e8..c5c2b5b44451 100644
--- a/python/tvm/auto_scheduler/search_task.py
+++ b/python/tvm/auto_scheduler/search_task.py
@@ -299,13 +299,18 @@ def get_task_input_buffer(workload_key, input_name):
         TASK_INPUT_BUFFER_TABLE[workload_key] = {}
     input_table = TASK_INPUT_BUFFER_TABLE[workload_key]
 
-    if input_name not in input_table.keys():
+    if input_name not in input_table:
         # Try to load buffer data from local file
         tensor_from_file = _try_load_buffer_from_file(input_name)
         if tensor_from_file:
            input_table[input_name] = tensor_from_file
 
-    if input_name in input_table.keys():
+    # Then check the default table: inputs extracted from a relay model are stored
+    # there because we're not able to get the workload_key at that time
+    if input_name not in input_table:
+        input_table = TASK_INPUT_BUFFER_TABLE["default"]
+
+    if input_name in input_table:
         return input_table[input_name]
 
     raise ValueError(
diff --git a/python/tvm/relay/analysis/sparse_dense.py b/python/tvm/relay/analysis/sparse_dense.py
index d521748f2311..23929f45917d 100644
--- a/python/tvm/relay/analysis/sparse_dense.py
+++ b/python/tvm/relay/analysis/sparse_dense.py
@@ -73,6 +73,12 @@ def process_params(expr, params, block_size, sparsity_threshold):
     ret : Namedtuple[weight_name: Array[String], weight_shape: Array[Array[IntImm]]]
         return names of qualified dense weight and the shape in BSR format
     """
+
+    # pylint: disable=import-outside-toplevel
+    from tvm.auto_scheduler.search_task import (
+        register_task_input_buffer,
+    )  # lazily import to avoid recursive dependency
+
     memo = SparseAnalysisResult(weight_name=[], weight_shape=[])
     weight_names = _search_dense_op_weight(expr)
     for name in weight_names:
@@ -92,6 +98,23 @@ def process_params(expr, params, block_size, sparsity_threshold):
             params[name + ".data"] = tvm.nd.array(sparse_weight.data)
             params[name + ".indices"] = tvm.nd.array(sparse_weight.indices)
             params[name + ".indptr"] = tvm.nd.array(sparse_weight.indptr)
+
+            prefix = "sparse_dense_bsr_%d_%d_%d_%d_%.2f_" % (
+                w_np.shape[0],
+                w_np.shape[1],
+                block_size[0],
+                block_size[1],
+                1 - sparsity,
+            )
+            register_task_input_buffer(
+                "default", prefix + "W_data", tvm.runtime.ndarray.array(sparse_weight.data)
+            )
+            register_task_input_buffer(
+                "default", prefix + "W_indices", tvm.runtime.ndarray.array(sparse_weight.indices)
+            )
+            register_task_input_buffer(
+                "default", prefix + "W_indptr", tvm.runtime.ndarray.array(sparse_weight.indptr)
+            )
     ret = SparseAnalysisResult(
         weight_name=tvm.runtime.convert(memo.weight_name),
         weight_shape=tvm.runtime.convert(memo.weight_shape),
diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py
index 1b593ad8dea3..9deb8cd35b33 100644
--- a/python/tvm/relay/frontend/tflite.py
+++ b/python/tvm/relay/frontend/tflite.py
@@ -1872,7 +1872,7 @@ def convert_fully_connected(self, op):
                 out_dtype="int32",
             )
         else:
-            out = _op.nn.dense(in_expr, weight_expr)
+            out = _op.nn.dense(in_expr, weight_expr, units=weight_shape[0])
 
         # if we have bias
         if len(input_tensors) == 3:
diff --git a/python/tvm/topi/nn/sparse.py b/python/tvm/topi/nn/sparse.py
index 756110624aa1..f5737d087fc7 100644
--- a/python/tvm/topi/nn/sparse.py
+++ b/python/tvm/topi/nn/sparse.py
@@ -426,7 +426,7 @@ def _process_inputs(input_tensors, m, n, prefix_init):
         density *= i
     density /= k * n
     density = density.value
-    sparse_prefix = "%s_%d_%d_%d_%d_%d_%.2f_" % (prefix_init, m, n, k, bs_r, bs_c, density)
+    sparse_prefix = "%s_%d_%d_%d_%d_%.2f_" % (prefix_init, n, k, bs_r, bs_c, density)
 
     visited = set()
 
diff --git a/python/tvm/topi/sparse/utils.py b/python/tvm/topi/sparse/utils.py
new file mode 100644
index 000000000000..43bc6e021429
--- /dev/null
+++ b/python/tvm/topi/sparse/utils.py
@@ -0,0 +1,127 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Some utils for Sparse operation."""
+import tvm
+from tvm import relay
+from tvm.relay import data_dep_optimization as ddo
+
+
+def random_bsr_matrix(m, n, bs_r, bs_c, density, dtype):
+    """Generate a random sparse matrix in BSR format.
+
+    Returns
+    -------
+    scipy.sparse.bsr_matrix
+    """
+    # pylint: disable=import-outside-toplevel
+    import numpy as np
+    import itertools
+    import scipy.sparse as sp
+
+    y = np.zeros((m, n), dtype=dtype)
+    assert m % bs_r == 0
+    assert n % bs_c == 0
+    nnz = int(density * m * n)
+    num_blocks = int(nnz / (bs_r * bs_c)) + 1
+    candidate_blocks = np.asarray(list(itertools.product(range(0, m, bs_r), range(0, n, bs_c))))
+    assert candidate_blocks.shape[0] == m // bs_r * n // bs_c
+    chosen_blocks = candidate_blocks[
+        np.random.choice(candidate_blocks.shape[0], size=num_blocks, replace=False)
+    ]
+    # pylint: disable=invalid-name
+    for (r, c) in chosen_blocks:
+        y[r : r + bs_r, c : c + bs_c] = np.random.randn(bs_r, bs_c)
+    s = sp.bsr_matrix(y, blocksize=(bs_r, bs_c))
+    assert s.data.shape == (num_blocks, bs_r, bs_c)
+    assert s.indices.shape == (num_blocks,)
+    assert s.indptr.shape == (m // bs_r + 1,)
+    return s
+
+
+def random_sparse_dense_params(func, params, bs_r, bs_c, density):
+    """Replace the dense parameters with random sparse parameters. Mainly used for testing.
+
+    Parameters
+    ----------
+    func : tvm.relay.Expr
+        The Expr that will be optimized to a sparse operation.
+    params : Dict[String, tvm.nd.array]
+        Parameters of the Expr.
+    bs_r : int
+        The number of rows in a BSR matrix block.
+    bs_c : int
+        The number of columns in a BSR matrix block.
+    density : float
+        The density of the random sparse parameters.
+
+    Returns
+    -------
+    Dict[String, tvm.nd.array]
+        The generated random parameters.
+ """ + + def deepcopy(param_dic): + ret = {} + for k, v in param_dic.items(): + ret[k] = tvm.nd.array(v.asnumpy()) + return ret + + new_params = deepcopy(params) + dense_weight_names = relay.analysis.sparse_dense._search_dense_op_weight(func) + for item in dense_weight_names: + name = str(item) + shape = new_params[name].shape + if shape[0] % bs_r == 0 and shape[1] % bs_c == 0: + new_w = random_bsr_matrix(shape[0], shape[1], bs_r, bs_c, density, "float32").todense() + new_params[name] = tvm.nd.array(new_w) + return new_params + + +def convert_model_dense_to_sparse(mod, params, random_params=False, bs_r=1, bs_c=1, sparsity=0.85): + """Convert a dense model to sparse model. + + Parameters + ---------- + mod : tvm.Module + The dense model. + params : Dict[Srting, tvm.nd.array] + Parameters of the dense model. + random_params : Bool = False + True to replace the parameters of the dense model with some random sparse tensors. + This is mainly used for testing. + bs_r : int + The row of BSR matrix block. + bs_c : int + The column of BSR matrix block. + sparsity : float + The sparsity of the random sparse parameters. + + Returns + ------- + tvm.Module + The updated sparse model. + Dict[Srting, tvm.nd.array] + The updated parameters. + """ + mod, params = ddo.simplify_fc_transpose.convert(mod["main"], params) + if random_params: + # Manually replace the parameters of dense model to sparse tensors + params = random_sparse_dense_params(mod, params, bs_r=bs_r, bs_c=bs_c, density=1 - sparsity) + # Currently we only support to conver dense matmul to sparse dense matmul + mod, params = ddo.bsr_dense.convert(mod, params, (bs_r, bs_c), sparsity_threshold=0.8) + + return tvm.IRModule.from_expr(mod), params diff --git a/tutorials/auto_scheduler/tune_network_arm.py b/tutorials/auto_scheduler/tune_network_arm.py index c4add79450e9..5a8407eb8675 100644 --- a/tutorials/auto_scheduler/tune_network_arm.py +++ b/tutorials/auto_scheduler/tune_network_arm.py @@ -17,7 +17,9 @@ """ Auto-scheduling a Neural Network for ARM CPU ============================================= -**Author**: `Thierry Moreau >`_ +**Author**: `Thierry Moreau _`, \ + `Lianmin Zheng _`, \ + `Chengfan Jia `_ Auto-tuning for specific devices and workloads is critical for getting the best performance. This is a tutorial on how to tune a whole neural @@ -45,9 +47,11 @@ """ import numpy as np +import os import tvm from tvm import relay, auto_scheduler +from tvm.relay import data_dep_optimization as ddo import tvm.relay.testing from tvm.contrib import graph_runtime from tvm.contrib.utils import tempdir @@ -67,7 +71,7 @@ # You can use :ref:`ConvertLayout ` pass to do the layout conversion in TVM. 
-def get_network(name, batch_size, layout="NHWC", dtype="float32"):
+def get_network(name, batch_size, layout="NHWC", dtype="float32", use_sparse=False):
     """Get the symbol definition and random weight of a network"""
 
     # auto-scheduler prefers NHWC layout
@@ -127,6 +131,17 @@ def get_network(name, batch_size, layout="NHWC", dtype="float32"):
             net.params, relay.nn.softmax(net.body), None, net.type_params, net.attrs
         )
         mod = tvm.IRModule.from_expr(net)
+    elif name == "mlp":
+        mod, params = relay.testing.mlp.get_workload(
+            batch_size=batch_size, dtype=dtype, image_shape=image_shape, num_classes=1000
+        )
+    else:
+        raise ValueError("Network not found.")
+
+    if use_sparse:
+        from tvm.topi.sparse.utils import convert_model_dense_to_sparse
+
+        mod, params = convert_model_dense_to_sparse(mod, params, random_params=True)
 
     return mod, params, input_shape, output_shape
 
@@ -217,8 +232,10 @@ def get_network(name, batch_size, layout="NHWC", dtype="float32"):
 # because we're sharing x86 op strategy.
 target = tvm.target.Target("llvm -mtriple=aarch64-linux-gnu -mattr=+neon")
 
-# Also replace this with the device key in your tracker
+# Also replace this with the device key, RPC host and RPC port of your tracker
 device_key = "rasp4b-64"
+rpc_host = "0.0.0.0"
+rpc_port = 9191
 
 # Set this to True if you use ndk tools for cross compiling
 # And also set the environment variable below to point to the cross compiler
@@ -227,6 +244,7 @@ def get_network(name, batch_size, layout="NHWC", dtype="float32"):
 
 #### TUNING OPTION ####
 network = "mobilenet"
+use_sparse = False
 batch_size = 1
 layout = "NHWC"
 dtype = "float32"
@@ -244,8 +262,11 @@ def get_network(name, batch_size, layout="NHWC", dtype="float32"):
 # The task scheduler will just optimize this objective.
 
 # Extract tasks from the network
+print("Get model...")
+mod, params, input_shape, output_shape = get_network(
+    network, batch_size, layout, dtype=dtype, use_sparse=use_sparse
+)
 print("Extract tasks...")
-mod, params, input_shape, output_shape = get_network(network, batch_size, layout, dtype=dtype)
 tasks, task_weights = auto_scheduler.extract_tasks(mod["main"], params, target)
 
 for idx, task in enumerate(tasks):
@@ -280,10 +301,11 @@ def tune_and_evaluate():
     tuner = auto_scheduler.TaskScheduler(tasks, task_weights)
     tune_option = auto_scheduler.TuningOptions(
         num_measure_trials=200,  # change this to 20000 to achieve the best performance
+        builder=auto_scheduler.LocalBuilder(build_func="ndk" if use_ndk else "default"),
         runner=auto_scheduler.RPCRunner(
             device_key,
-            host="0.0.0.0",
-            port=9191,
+            host=rpc_host,
+            port=rpc_port,
             timeout=30,
             repeat=1,
             min_repeat_ms=200,
@@ -315,7 +337,7 @@ def tune_and_evaluate():
 
     # Upload module to device
     print("Upload...")
-    remote = auto_scheduler.utils.request_remote(device_key, "0.0.0.0", 9191, timeout=10000)
+    remote = auto_scheduler.utils.request_remote(device_key, rpc_host, rpc_port, timeout=10000)
     remote.upload(tmp.relpath(filename))
     rlib = remote.load_module(filename)
 
diff --git a/tutorials/auto_scheduler/tune_network_x86.py b/tutorials/auto_scheduler/tune_network_x86.py
index 8526abbbe6ca..2839db8646d0 100644
--- a/tutorials/auto_scheduler/tune_network_x86.py
+++ b/tutorials/auto_scheduler/tune_network_x86.py
@@ -17,7 +17,8 @@
 """
 Auto-scheduling a Neural Network for x86 CPU
 ============================================
-**Author**: `Lianmin Zheng <https://github.com/merrymercy>`_
+**Author**: `Lianmin Zheng <https://github.com/merrymercy>`_, \
+            `Chengfan Jia <https://github.com/jcf94/>`_
 
 Auto-tuning for specific devices and workloads is critical for getting the
 best performance. This is a tutorial on how to tune a whole neural
This is a tutorial on how to tune a whole neural @@ -48,6 +49,7 @@ import tvm from tvm import relay, auto_scheduler +from tvm.relay import data_dep_optimization as ddo import tvm.relay.testing from tvm.contrib import graph_runtime @@ -66,7 +68,7 @@ # You can use :ref:`ConvertLayout ` pass to do the layout conversion in TVM. -def get_network(name, batch_size, layout="NHWC", dtype="float32"): +def get_network(name, batch_size, layout="NHWC", dtype="float32", use_sparse=False): """Get the symbol definition and random weight of a network""" # auto-scheduler prefers NHWC layout @@ -126,6 +128,17 @@ def get_network(name, batch_size, layout="NHWC", dtype="float32"): net.params, relay.nn.softmax(net.body), None, net.type_params, net.attrs ) mod = tvm.IRModule.from_expr(net) + elif name == "mlp": + mod, params = relay.testing.mlp.get_workload( + batch_size=batch_size, dtype=dtype, image_shape=image_shape, num_classes=1000 + ) + else: + raise ValueError("Network not found.") + + if use_sparse: + from tvm.topi.sparse.utils import convert_model_dense_to_sparse + + mod, params = convert_model_dense_to_sparse(mod, params, random_params=True) return mod, params, input_shape, output_shape @@ -134,6 +147,7 @@ def get_network(name, batch_size, layout="NHWC", dtype="float32"): # If the target machine supports avx512 instructions, replace the # "llvm -mcpu=core-avx2" with "llvm -mcpu=skylake-avx512" network = "resnet-50" +use_sparse = False batch_size = 1 layout = "NHWC" target = tvm.target.Target("llvm -mcpu=core-avx2") @@ -152,8 +166,11 @@ def get_network(name, batch_size, layout="NHWC", dtype="float32"): # The task scheduler will just optimize this objective. # Extract tasks from the network +print("Get model...") +mod, params, input_shape, output_shape = get_network( + network, batch_size, layout, dtype=dtype, use_sparse=use_sparse +) print("Extract tasks...") -mod, params, input_shape, output_shape = get_network(network, batch_size, layout, dtype=dtype) tasks, task_weights = auto_scheduler.extract_tasks(mod["main"], params, target) for idx, task in enumerate(tasks): diff --git a/tutorials/auto_scheduler/tune_sparse_x86.py b/tutorials/auto_scheduler/tune_sparse_x86.py index ced416f6c500..ad3646dfc19d 100644 --- a/tutorials/auto_scheduler/tune_sparse_x86.py +++ b/tutorials/auto_scheduler/tune_sparse_x86.py @@ -36,15 +36,13 @@ """ import os -import itertools import numpy as np import tvm from tvm import te, auto_scheduler, runtime, topi from tvm.auto_scheduler import _ffi_api from tvm.topi.utils import get_const_tuple - -import scipy.sparse as sp +from tvm.topi.sparse.utils import random_bsr_matrix ###################################################################### # Define the computation @@ -53,29 +51,6 @@ # The function should return the list of input/output tensors. # From these tensors, the auto-scheduler can get the whole computational graph. 
-# We use this function to generate a random bsr matrix
-def random_bsr_matrix(M, N, BS_R, BS_C, density, dtype):
-    import itertools
-
-    Y = np.zeros((M, N), dtype=dtype)
-    assert M % BS_R == 0
-    assert N % BS_C == 0
-    nnz = int(density * M * N)
-    num_blocks = int(nnz / (BS_R * BS_C)) + 1
-    candidate_blocks = np.asarray(list(itertools.product(range(0, M, BS_R), range(0, N, BS_C))))
-    assert candidate_blocks.shape[0] == M // BS_R * N // BS_C
-    chosen_blocks = candidate_blocks[
-        np.random.choice(candidate_blocks.shape[0], size=num_blocks, replace=False)
-    ]
-    for i in range(len(chosen_blocks)):
-        r, c = chosen_blocks[i]
-        Y[r : r + BS_R, c : c + BS_C] = np.random.randn(BS_R, BS_C)
-    s = sp.bsr_matrix(Y, blocksize=(BS_R, BS_C))
-    assert s.data.shape == (num_blocks, BS_R, BS_C)
-    assert s.indices.shape == (num_blocks,)
-    assert s.indptr.shape == (M // BS_R + 1,)
-    return s
-
 
 @auto_scheduler.register_workload
 def sparse_dense(M, N, K, w_data_shape, w_indices_shape, w_indptr_shape, dtype):
@@ -104,7 +79,9 @@ def sparse_dense(M, N, K, w_data_shape, w_indices_shape, w_indptr_shape, dtype):
 # See the `tvm.auto_scheduler.measure.py` for more details.
 
 # Define the basic shapes of this sparse computation
-M = K = N = 512
+M = 128
+K = 256
+N = 512
 BS_R = 16
 BS_C = 1
 density = 0.6
@@ -131,7 +108,7 @@ def sparse_dense(M, N, K, w_data_shape, w_indices_shape, w_indptr_shape, dtype):
 target = tvm.target.Target("llvm")
 
 # Register the sparse data to task inputs
-prefix = "sparse_dense_bsr_%d_%d_%d_%d_%d_%.2f_" % (M, N, K, BS_R, BS_C, density)
+prefix = "sparse_dense_bsr_%d_%d_%d_%d_%.2f_" % (N, K, BS_R, BS_C, density)
 task = tvm.auto_scheduler.SearchTask(
     func=sparse_dense,
     args=(M, N, K, W_sp_np.data.shape, W_sp_np.indices.shape, W_sp_np.indptr.shape, "float32"),
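
Usage sketch: a minimal end-to-end run of the sparse tuning flow these changes enable, combining the new `convert_model_dense_to_sparse` helper with `extract_tasks`. The MLP workload, block sizes, log-file name, and trial count below are illustrative choices, not taken from the patch.

    import tvm
    from tvm import relay, auto_scheduler
    import tvm.relay.testing
    from tvm.topi.sparse.utils import convert_model_dense_to_sparse

    # Build a small dense MLP and replace its dense weights with random BSR tensors,
    # mirroring the `use_sparse` path added to the network tutorials above.
    mod, params = relay.testing.mlp.get_workload(batch_size=1, dtype="float32")
    mod, params = convert_model_dense_to_sparse(mod, params, random_params=True, bs_r=1, bs_c=1)

    target = tvm.target.Target("llvm")

    # extract_tasks now records the special input names for each task (task_inputs);
    # during measurement the buffers registered under the "default" table are looked up,
    # so the kernels run on the real W_data/W_indices/W_indptr contents.
    tasks, task_weights = auto_scheduler.extract_tasks(mod["main"], params, target)

    tuner = auto_scheduler.TaskScheduler(tasks, task_weights)
    tune_option = auto_scheduler.TuningOptions(
        num_measure_trials=200,  # illustrative budget; raise this for real tuning
        measure_callbacks=[auto_scheduler.RecordToFile("sparse_mlp_tuning.json")],
    )
    tuner.tune(tune_option)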