Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Autoscheduler][Sparse] Add sparse dense end to end model tuning support for x86/arm cpu & Some bug fix #7635

Merged
merged 27 commits into from
Mar 30, 2021
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions python/tvm/auto_scheduler/measure.py
Original file line number Diff line number Diff line change
Expand Up @@ -775,7 +775,7 @@ def register(myf):
return register


def _prepare_input_map(args):
def prepare_input_map(args):
"""This function deals with special task inputs. Map the input Tensor of a TVM subgraph
to a specific buffer name in the global buffer map.

Expand Down Expand Up @@ -861,7 +861,7 @@ def _timed_eval_func(
random_fill = tvm.get_global_func("tvm.contrib.random.random_fill", True)
assert random_fill, "Please make sure USE_RANDOM is ON in the config.cmake"

tensor_input_map = _prepare_input_map(build_res.args) if task_input_names else {}
tensor_input_map = prepare_input_map(build_res.args) if task_input_names else {}
args = []
task_inputs_count = 0
for arg in build_res.args:
Expand Down Expand Up @@ -1076,7 +1076,7 @@ def _timed_rpc_run(
random_fill
), "Please make sure USE_RANDOM is ON in the config.cmake on the remote devices"

tensor_input_map = _prepare_input_map(build_res.args) if task_input_names else {}
tensor_input_map = prepare_input_map(build_res.args) if task_input_names else {}
args = []
task_inputs_count = 0
for arg in build_res.args:
Expand Down
29 changes: 27 additions & 2 deletions python/tvm/auto_scheduler/relay_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,11 @@ def extract_tasks(
# When auto scheduler is used in end to end network, try to apply layout rewrite
# to improve the overall performance
layout_rewrite_option=LayoutRewriteOption.get_target_default(target, True),
task_inputs=(
env.wkl_key_to_input_names[wkl_key]
if wkl_key in env.wkl_key_to_input_names
else None
),
)
)
weights.append(weight)
Expand All @@ -166,6 +171,7 @@ def __init__(self, tracing_mode):
self.tracing_mode = tracing_mode
self.relay_disable_build_cache = "false"
self.wkl_key_to_weight = {}
self.wkl_key_to_input_names = {}

def __enter__(self):
TracingEnvironment.current = self
Expand All @@ -175,17 +181,30 @@ def __exit__(self, exc_type, exc_val, exc_tb):
TracingEnvironment.current = None

def add_workload_key(self, workload_key):
"""Add the workload key of a search task
"""Add the workload key of a search task.

Parameters
----------
workload_key: str
The workload key of a task
The workload key of a task.
"""
if workload_key not in self.wkl_key_to_weight:
self.wkl_key_to_weight[workload_key] = 0
self.wkl_key_to_weight[workload_key] += 1

def add_workload_input_names(self, workload_key, input_names):
"""Add special task inputs to this workload.

Parameters
----------
workload_key : str
The workload key of a task.

input_names : List[str]
A list of input names.
"""
self.wkl_key_to_input_names[workload_key] = input_names


@tvm._ffi.register_func("auto_scheduler.enter_layout_rewrite")
def enter_layout_rewrite():
Expand Down Expand Up @@ -274,6 +293,9 @@ def auto_schedule_topi(outs):
None in the tracing mode so that the fallback topi schedule will be used.
"""
# pylint: disable=import-outside-toplevel
from tvm.auto_scheduler.measure import (
prepare_input_map,
) # lazily import to avoid recursive dependency

io_tensors, has_layout_free, has_complex_op = traverse_to_get_io_tensors(outs)
if not io_tensors: # The compute includes dynamic shapes which are not supported yet.
Expand Down Expand Up @@ -305,6 +327,9 @@ def auto_schedule_topi(outs):
# in the task extraction mode
if has_complex_op or env.tracing_mode == TracingMode.EXTRACT_TASK:
env.add_workload_key(key)
input_map = prepare_input_map(io_tensors)
if input_map:
env.add_workload_input_names(key, list(input_map.values()))
elif env.tracing_mode == TracingMode.PREPARE_LAYOUT_REWRITE:
# in prepare_layout_rewrite mode
if (
Expand Down
9 changes: 7 additions & 2 deletions python/tvm/auto_scheduler/search_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,13 +299,18 @@ def get_task_input_buffer(workload_key, input_name):
TASK_INPUT_BUFFER_TABLE[workload_key] = {}
input_table = TASK_INPUT_BUFFER_TABLE[workload_key]

if input_name not in input_table.keys():
if input_name not in input_table:
# Try to load buffer data from local file
tensor_from_file = _try_load_buffer_from_file(input_name)
if tensor_from_file:
input_table[input_name] = tensor_from_file

if input_name in input_table.keys():
# Then check for the default table, the input names extracted from a relay model will be
# stored here for we're not able to get the workload_key at that time
if input_name not in input_table:
input_table = TASK_INPUT_BUFFER_TABLE["default"]

if input_name in input_table:
return input_table[input_name]

raise ValueError(
Expand Down
23 changes: 23 additions & 0 deletions python/tvm/relay/analysis/sparse_dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,12 @@ def process_params(expr, params, block_size, sparsity_threshold):
ret : Namedtuple[weight_name: Array[String], weight_shape: Array[Array[IntImm]]]
return names of qualified dense weight and the shape in BSR format
"""

# pylint: disable=import-outside-toplevel
from tvm.auto_scheduler.search_task import (
register_task_input_buffer,
) # lazily import to avoid recursive dependency

memo = SparseAnalysisResult(weight_name=[], weight_shape=[])
weight_names = _search_dense_op_weight(expr)
for name in weight_names:
Expand All @@ -92,6 +98,23 @@ def process_params(expr, params, block_size, sparsity_threshold):
params[name + ".data"] = tvm.nd.array(sparse_weight.data)
params[name + ".indices"] = tvm.nd.array(sparse_weight.indices)
params[name + ".indptr"] = tvm.nd.array(sparse_weight.indptr)

prefix = "sparse_dense_bsr_%d_%d_%d_%d_%.2f_" % (
w_np.shape[0],
w_np.shape[1],
block_size[0],
block_size[1],
1 - sparsity,
)
register_task_input_buffer(
"default", prefix + "W_data", tvm.runtime.ndarray.array(sparse_weight.data)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's not hard-code it; we can use {name + ".data", name + ".indices", name + ".indptr"} instead.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The problem is that we cannot get the "name" during measuring.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, thanks for the clarification. But I wonder: if the name is not available then, how does the logic above the prefix work (I mean line 98)? It's in the same flow, right? Please let me know if I am mistaken.

)
register_task_input_buffer(
"default", prefix + "W_indices", tvm.runtime.ndarray.array(sparse_weight.indices)
)
register_task_input_buffer(
"default", prefix + "W_indptr", tvm.runtime.ndarray.array(sparse_weight.indptr)
)
ret = SparseAnalysisResult(
weight_name=tvm.runtime.convert(memo.weight_name),
weight_shape=tvm.runtime.convert(memo.weight_shape),
Expand Down
34 changes: 33 additions & 1 deletion python/tvm/topi/nn/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,7 +426,7 @@ def _process_inputs(input_tensors, m, n, prefix_init):
density *= i
density /= k * n
density = density.value
sparse_prefix = "%s_%d_%d_%d_%d_%d_%.2f_" % (prefix_init, m, n, k, bs_r, bs_c, density)
sparse_prefix = "%s_%d_%d_%d_%d_%.2f_" % (prefix_init, n, k, bs_r, bs_c, density)

visited = set()

Expand Down Expand Up @@ -468,3 +468,35 @@ def _traverse(t):
sparse_input_map[sparse_indptr] = sparse_prefix + "W_indptr"

return sparse_input_map


def random_bsr_matrix(m, n, bs_r, bs_c, density, dtype):
Copy link
Contributor

@ANSHUMAN87 ANSHUMAN87 Mar 12, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMO this should not be part of Topi. You can either put it where it is used or in testing.

Copy link
Contributor Author

@jcf94 jcf94 Mar 12, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMO this should not be part of Topi. You can either put it where it is used or in testing.

Fine. It's just that I've found this is used in many different places. I'll try to find a better position.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moved to topi/sparse/utils.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am sorry for such a delayed response; I somehow missed your reply. My suggestion is that random_bsr_matrix() does not qualify to be in Topi unless it is required by some ops. From what I can see, it is just a utility for the tutorial, so let's keep this utility function in the tutorial file itself. Otherwise, we have one more option: we can make it part of tvm.testing, which can help other tutorials and test cases as well.

"""Generate a random sparse matrix in bsr format.

Returns
-------
scipy.sparse.bsr_matrix
"""
# pylint: disable=import-outside-toplevel
import numpy as np
import itertools
import scipy.sparse as sp

y = np.zeros((m, n), dtype=dtype)
assert m % bs_r == 0
assert n % bs_c == 0
nnz = int(density * m * n)
num_blocks = int(nnz / (bs_r * bs_c)) + 1
candidate_blocks = np.asarray(list(itertools.product(range(0, m, bs_r), range(0, n, bs_c))))
assert candidate_blocks.shape[0] == m // bs_r * n // bs_c
chosen_blocks = candidate_blocks[
np.random.choice(candidate_blocks.shape[0], size=num_blocks, replace=False)
]
# pylint: disable=invalid-name
for (r, c) in chosen_blocks:
y[r : r + bs_r, c : c + bs_c] = np.random.randn(bs_r, bs_c)
s = sp.bsr_matrix(y, blocksize=(bs_r, bs_c))
assert s.data.shape == (num_blocks, bs_r, bs_c)
assert s.indices.shape == (num_blocks,)
assert s.indptr.shape == (m // bs_r + 1,)
return s
43 changes: 42 additions & 1 deletion tutorials/auto_scheduler/tune_network_x86.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
"""
Auto-scheduling a Neural Network for x86 CPU
============================================
**Author**: `Lianmin Zheng <https://github.com/merrymercy>`_
**Author**: `Lianmin Zheng <https://github.com/merrymercy>`_, \
`Chengfan Jia <https://github.com/jcf94/>`_

Auto-tuning for specific devices and workloads is critical for getting the
best performance. This is a tutorial on how to tune a whole neural
Expand Down Expand Up @@ -48,6 +49,8 @@

import tvm
from tvm import relay, auto_scheduler
from tvm.relay import data_dep_optimization as ddo
from tvm.topi.nn.sparse import random_bsr_matrix
import tvm.relay.testing
from tvm.contrib import graph_runtime

Expand Down Expand Up @@ -126,6 +129,44 @@ def get_network(name, batch_size, layout="NHWC", dtype="float32"):
net.params, relay.nn.softmax(net.body), None, net.type_params, net.attrs
)
mod = tvm.IRModule.from_expr(net)
elif name == "mlp":
mod, params = relay.testing.mlp.get_workload(
batch_size=batch_size, dtype=dtype, image_shape=image_shape, num_classes=1000
)
elif name == "mlp-sparse":
# This is a test workload that manually transforms a dense model to sparse
# Check `tutorials/frontend/deploy_sparse.py` for more examples on how to import a
# pretrained model.

def random_sparse_params(func, params, density, BS_R, BS_C):
def deepcopy(param_dic):
ret = {}
for k, v in param_dic.items():
ret[k] = tvm.nd.array(v.asnumpy())
return ret

new_params = deepcopy(params)
dense_weight_names = relay.analysis.sparse_dense._search_dense_op_weight(func)
for item in dense_weight_names:
name = str(item)
shape = new_params[name].shape
if shape[0] % BS_R == 0 and shape[1] % BS_C == 0:
new_w = random_bsr_matrix(
shape[0], shape[1], BS_R, BS_C, density, "float32"
).todense()
new_params[name] = tvm.nd.array(new_w)
return new_params

bs_r = 1
sparsity = 0.85

mod, params = relay.testing.mlp.get_workload(
batch_size=batch_size, dtype=dtype, image_shape=image_shape, num_classes=1000
)
mod, params = ddo.simplify_fc_transpose.convert(mod["main"], params)
params = random_sparse_params(mod, params, BS_R=bs_r, BS_C=1, density=1 - sparsity)
mod, params = ddo.bsr_dense.convert(mod, params, (bs_r, 1), sparsity_threshold=0.8)
mod = tvm.IRModule.from_expr(mod)

return mod, params, input_shape, output_shape

Expand Down
33 changes: 5 additions & 28 deletions tutorials/auto_scheduler/tune_sparse_x86.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,13 @@
"""

import os
import itertools

import numpy as np
import tvm
from tvm import te, auto_scheduler, runtime, topi
from tvm.auto_scheduler import _ffi_api
from tvm.topi.utils import get_const_tuple

import scipy.sparse as sp
from tvm.topi.nn.sparse import random_bsr_matrix

######################################################################
# Define the computation
Expand All @@ -53,29 +51,6 @@
# The function should return the list of input/output tensors.
# From these tensors, the auto-scheduler can get the whole computational graph.

# We use this function to generate a random bsr matrix
def random_bsr_matrix(M, N, BS_R, BS_C, density, dtype):
    """Generate a random sparse matrix in BSR format.

    A dense (M, N) zero matrix is created, a number of (BS_R, BS_C) blocks is
    picked at random without replacement, each picked block is filled with
    samples from ``np.random.randn``, and the result is converted to BSR.

    Parameters
    ----------
    M : int
        Number of rows. Must be a multiple of ``BS_R``.
    N : int
        Number of columns. Must be a multiple of ``BS_C``.
    BS_R : int
        Block height.
    BS_C : int
        Block width.
    density : float
        Target fraction of non-zero elements; it determines how many blocks
        are filled.
    dtype : str or numpy dtype
        Element type of the generated matrix.

    Returns
    -------
    scipy.sparse.bsr_matrix
        The generated matrix with block size ``(BS_R, BS_C)``.
    """
    import itertools

    assert M % BS_R == 0
    assert N % BS_C == 0
    Y = np.zeros((M, N), dtype=dtype)
    total_blocks = (M // BS_R) * (N // BS_C)
    nnz = int(density * M * N)
    # The "+ 1" keeps at least one non-zero block for tiny densities. Clamp to the
    # number of existing blocks so that densities at (or close to) 1.0 do not request
    # more blocks than can be sampled without replacement.
    num_blocks = min(int(nnz / (BS_R * BS_C)) + 1, total_blocks)
    candidate_blocks = np.asarray(list(itertools.product(range(0, M, BS_R), range(0, N, BS_C))))
    assert candidate_blocks.shape[0] == total_blocks
    chosen_blocks = candidate_blocks[
        np.random.choice(candidate_blocks.shape[0], size=num_blocks, replace=False)
    ]
    # Each row of chosen_blocks is the (row, col) origin of one block to fill.
    for r, c in chosen_blocks:
        Y[r : r + BS_R, c : c + BS_C] = np.random.randn(BS_R, BS_C)
    s = sp.bsr_matrix(Y, blocksize=(BS_R, BS_C))
    # Sanity-check that the BSR layout matches the number of blocks we filled.
    assert s.data.shape == (num_blocks, BS_R, BS_C)
    assert s.indices.shape == (num_blocks,)
    assert s.indptr.shape == (M // BS_R + 1,)
    return s


@auto_scheduler.register_workload
def sparse_dense(M, N, K, w_data_shape, w_indices_shape, w_indptr_shape, dtype):
Expand Down Expand Up @@ -104,7 +79,9 @@ def sparse_dense(M, N, K, w_data_shape, w_indices_shape, w_indptr_shape, dtype):
# See the `tvm.auto_scheduler.measure.py` for more details.

# Define the basic shapes of this sparse computation
M = K = N = 512
M = 128
K = 256
N = 512
BS_R = 16
BS_C = 1
density = 0.6
Expand All @@ -131,7 +108,7 @@ def sparse_dense(M, N, K, w_data_shape, w_indices_shape, w_indptr_shape, dtype):
target = tvm.target.Target("llvm")

# Register the sparse data to task inputs
prefix = "sparse_dense_bsr_%d_%d_%d_%d_%d_%.2f_" % (M, N, K, BS_R, BS_C, density)
prefix = "sparse_dense_bsr_%d_%d_%d_%d_%.2f_" % (N, K, BS_R, BS_C, density)
task = tvm.auto_scheduler.SearchTask(
func=sparse_dense,
args=(M, N, K, W_sp_np.data.shape, W_sp_np.indices.shape, W_sp_np.indptr.shape, "float32"),
Expand Down