From 95fd68661cc5153f87362e098a9d2ddb516a00ec Mon Sep 17 00:00:00 2001
From: Tantalus
Date: Tue, 1 Jun 2021 16:32:28 +0800
Subject: [PATCH 01/21] [topi] add spconv2d_3x3 nhwc

---
 python/tvm/topi/x86/sparse.py | 101 +++++++++++++++++++++++++++++++++-
 1 file changed, 100 insertions(+), 1 deletion(-)

diff --git a/python/tvm/topi/x86/sparse.py b/python/tvm/topi/x86/sparse.py
index c6300f6701e0..4f36ce23dbfd 100644
--- a/python/tvm/topi/x86/sparse.py
+++ b/python/tvm/topi/x86/sparse.py
@@ -16,8 +16,10 @@
 # under the License.
 """sparse_dense schedule on x86"""
-from tvm import te
+from tvm import te, tir, autotvm
+from functools import partial, reduce
 
+from ..transform import reshape
 from ..utils import traverse_inline, get_const_int
 from .utils import get_fp32_len
 
@@ -60,3 +62,100 @@ def _callback(op):
 
     traverse_inline(s, outs[0].op, _callback)
     return s
+
+
+@autotvm.register_topi_compute('conv3x3_spNHWC.x86')
+def spconv2d_3x3_nhwc(cfg, Data, Wdat, Wind, Wptr):
+    '''# My SpConv2d_3x3_gemm
+
+    Data: N,H,W,C -> NHW,33C
+    Weight: F,3,3,C -> F,33C
+
+    yt, xt, yo =>
+      yi, k9, ci:vec =>
+        @im2col = {yt, yo, yi}/y, {k9, ci}/k
+      xo =>
+        x1:1, ko:dyn(xr), yi:unroll, xi:vec, ki:unroll =>
+          @CC = {yt, yo, yi}/y, {xt, xo, x1}/xr, xi, ki // ko
+      yi:unroll, xi:vec, ki:unroll =>
+        @C = {yt, yo, yi}/y, {xt, xo, xi}/x // ki
+    '''
+    N, H, W, CI = [i.value for i in Data.shape]
+    nElems, bsrR, bsrC = [i.value for i in Wdat.shape]
+    CO = (Wptr.shape[0].value - 1) * bsrR
+
+    Y, X, K = N*H*W, CO, 9*CI
+    # cfg = autotvm.get_config()
+    cfg.define_split("tile_y", Y, num_outputs=3)
+    cfg.define_split("tile_x", X // bsrR, num_outputs=2)
+    cfg.add_flop(Y * (nElems * bsrC * bsrR * 2 - X))
+    #cfg.define_split("tile_k", K, num_outputs=2)
+    if cfg.is_fallback:
+        cfg['tile_y'] = autotvm.task.space.SplitEntity([-1, 160, 8])
+        cfg['tile_x'] = autotvm.task.space.SplitEntity([-1, 4])
+
+    idxsplit = lambda x,y: reduce(lambda a,b: a[:-1]+[a[-1]%b,a[-1]//b], y, [x])
+
+    @partial(te.compute, (Y, K), name='Im2Col')
+    def Im2Col(row, col):
+        jw, jh, jn = idxsplit(row, [W, H])
+        jc, kw, kh = idxsplit(col, [CI, 3])
+        ih, iw = jh + kh - 1, jw + kw - 1
+        return tir.if_then_else(
+            tir.all(0 <= ih, ih < H, 0 <= iw, iw < W),
+            Data[jn, ih, iw, jc], 0)
+
+    @partial(te.compute, (Y, X // bsrR, bsrR, bsrC), name='CC')
+    def CC(drow, wrow, brow, bcol):
+        row_start, row_end = Wptr[wrow], Wptr[wrow+1]
+        elem_idx = te.reduce_axis((0, row_end - row_start), name='elem_idx')
+        elem = row_start + elem_idx
+        return te.sum(Im2Col[drow, Wind[elem]*bsrC + bcol] * Wdat[elem, brow, bcol], axis=elem_idx)
+
+    k = te.reduce_axis((0, bsrC), name='k')
+    C = te.compute((Y, X),
+        lambda y, x: te.sum(CC[y, x // bsrR, x % bsrR, k], axis=k),
+        name='C', tag='conv3x3_spNHWC')
+    return reshape(C, (N, H, W, CO))
+
+
+@autotvm.register_topi_schedule('conv3x3_spNHWC.x86')
+def schedule_spconv2d_3x3_nhwc(cfg, outs):
+    outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
+    s = te.create_schedule([x.op for x in outs])
+
+    def _callback(op):
+        if op.tag == 'conv3x3_spNHWC':
+            C = op
+            CC, = op.input_tensors
+            Wptr, Wind, Im2Col, Wdat = CC.op.input_tensors
+            Data, = Im2Col.op.input_tensors
+            bsrR = CC.shape[-2].value
+            CI = Data.shape[-1].value
+
+            y, x = s[C].op.axis
+            yt, yo, yi = cfg['tile_y'].apply(s, C, y)
+            xo, xi = s[C].split(x, factor=bsrR)
+            xt, xo = cfg['tile_x'].apply(s, C, xo)
+            (k,) = s[C].op.reduce_axis
+            s[C].reorder(yt, xt, yo, xo, yi, xi, k)
+            s[C].unroll(k)
+            s[C].vectorize(xi)
+            s[C].unroll(yi)
+
+            s[CC].compute_at(s[C], xo)
+            yi, xi, r, c = s[CC].op.axis
+            (k,) = s[CC].op.reduce_axis
+            s[CC].reorder(xi, k, yi, r, c)
+            s[CC].unroll(c)
+            s[CC].vectorize(r)
+            s[CC].unroll(yi)
+
+            s[Im2Col].compute_at(s[C], yo)
+            yi, k = s[Im2Col].op.axis
+            ko, ki = s[Im2Col].split(k, factor=CI)
+            s[Im2Col].vectorize(ki)
+            #s[Im2Col].unroll(yi)
+
+    traverse_inline(s, outs[0].op, _callback)
+    return s

From 756d6be9a45c4a203eecd5d1004cd36d98d3645b Mon Sep 17 00:00:00 2001
From: Tantalus
Date: Sat, 5 Jun 2021 12:30:54 +0800
Subject: [PATCH 02/21] [relay] sparse_conv2d: add kernel_size attr

---
 include/tvm/relay/attrs/nn.h                  |  3 ++
 python/tvm/relay/analysis/sparse_conv2d.py    | 53 ++++++++++++-------
 .../relay/data_dep_optimization/bsr_conv2d.py |  6 +--
 python/tvm/relay/op/nn/_nn.py                 |  2 +-
 python/tvm/relay/transform/transform.py       |  4 +-
 python/tvm/topi/nn/sparse.py                  |  2 +-
 src/relay/op/nn/sparse.cc                     |  3 +-
 src/relay/transforms/convert_sparse_conv2d.cc | 14 +++-
 8 files changed, 54 insertions(+), 33 deletions(-)

diff --git a/include/tvm/relay/attrs/nn.h b/include/tvm/relay/attrs/nn.h
index 694001f612e7..f4eff9120e4a 100644
--- a/include/tvm/relay/attrs/nn.h
+++ b/include/tvm/relay/attrs/nn.h
@@ -1047,12 +1047,15 @@ struct SparseTransposeAttrs : public tvm::AttrsNode<SparseTransposeAttrs> {
 /*! \brief Attributes for sparse_dense operator */
 struct SparseConv2DAttrs : public tvm::AttrsNode<SparseConv2DAttrs> {
   std::string layout;
+  int kernel_size;
 
   TVM_DECLARE_ATTRS(SparseConv2DAttrs, "relay.attrs.SparseConv2DAttrs") {
     TVM_ATTR_FIELD(layout).set_default("NHWC").describe(
         "Dimension ordering of input data. Can be 'NCHW', 'NHWC'"
         "'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
         "dimensions respectively.");
+    TVM_ATTR_FIELD(kernel_size).set_default(1).describe(
+        "Kernel size for SparseConv2D, 1x1 or 3x3. ");
   }
 };
 
diff --git a/python/tvm/relay/analysis/sparse_conv2d.py b/python/tvm/relay/analysis/sparse_conv2d.py
index 11278bddca33..07dc17611559 100644
--- a/python/tvm/relay/analysis/sparse_conv2d.py
+++ b/python/tvm/relay/analysis/sparse_conv2d.py
@@ -54,7 +54,7 @@ def _search_conv2d_op_weight(expr):
     return _ffi_api.search_conv2d_op_weight(expr)
 
 
-def process_params(expr, params, block_size, sparsity_threshold, layout):
+def process_params(expr, params, block_size, sparsity_threshold, layout, kernel_size, reg_task_input=True):
     """Process parameters of conv2d from dense to sparse.
Parameters @@ -86,14 +86,18 @@ def process_params(expr, params, block_size, sparsity_threshold, layout): for name in weight_names: name = str(name) w_np = params[name].numpy() - # currently only support conv2d_1*1 - if not ( - (w_np.shape[0] == 1 and w_np.shape[1] == 1) - or (w_np.shape[2] == 1 and w_np.shape[3] == 1) - ): + + if layout == "NHWC": # HWIO + weight_kernel = (w_np.shape[0], w_np.shape[1]) + elif layout == "NCHW": # OIHW + weight_kernel = (w_np.shape[2], w_np.shape[3]) + if weight_kernel[0] != weight_kernel[1]: continue - sparsity = 1.0 - (np.count_nonzero(w_np) / w_np.size) - if sparsity >= sparsity_threshold: + + if weight_kernel[0] == kernel_size == 1: + sparsity = 1.0 - (np.count_nonzero(w_np) / w_np.size) + if sparsity < sparsity_threshold: + continue if layout == "NHWC": w_np = w_np.squeeze().T elif layout == "NCHW": @@ -108,19 +112,28 @@ def process_params(expr, params, block_size, sparsity_threshold, layout): ) else: sparse_weight_data = sparse_weight.data + elif weight_kernel[0] == kernel_size == 3 and layout == "NHWC": + w_np = w_np.reshape((-1, w_np.shape[-1])).T + sparse_weight = sp.bsr_matrix(w_np, blocksize=block_size) + if 1 - (sparse_weight.nnz / w_np.size) < sparsity_threshold: + continue + sparse_weight_data = sparse_weight.data + else: + continue - # remove dense weight - del params[name] - memo.weight_name.append(name) - memo.weight_shape.append( - list(sparse_weight_data.shape) - + list(sparse_weight.indices.shape) - + list(sparse_weight.indptr.shape) - ) - params[name + ".data"] = tvm.nd.array(sparse_weight_data) - params[name + ".indices"] = tvm.nd.array(sparse_weight.indices) - params[name + ".indptr"] = tvm.nd.array(sparse_weight.indptr) - + # remove dense weight + del params[name] + memo.weight_name.append(name) + memo.weight_shape.append( + list(sparse_weight_data.shape) + + list(sparse_weight.indices.shape) + + list(sparse_weight.indptr.shape) + ) + params[name + ".data"] = tvm.nd.array(sparse_weight_data) + params[name + ".indices"] = tvm.nd.array(sparse_weight.indices) + params[name + ".indptr"] = tvm.nd.array(sparse_weight.indptr) + + if reg_task_input: prefix = "sparse_conv2d_bsr_%d_%d_%d_%d_%d_%d_" % ( w_np.shape[0], w_np.shape[1], diff --git a/python/tvm/relay/data_dep_optimization/bsr_conv2d.py b/python/tvm/relay/data_dep_optimization/bsr_conv2d.py index 6913a428b2ac..1bb3a687d9de 100644 --- a/python/tvm/relay/data_dep_optimization/bsr_conv2d.py +++ b/python/tvm/relay/data_dep_optimization/bsr_conv2d.py @@ -23,7 +23,7 @@ from .utils import _run_opt_pass -def convert(func, params, blocksize, sparsity_threshold, layout="NHWC"): +def convert(func, params, blocksize, sparsity_threshold, layout="NHWC", kernel_size=1): """Convert a dense func and according parameters to block sparse Parameters @@ -49,10 +49,10 @@ def convert(func, params, blocksize, sparsity_threshold, layout="NHWC"): params: Dict[Srting, tvm.nd.array] New params with BSR matrix for mutated Expr """ - weight_info = process_params(func, params, blocksize, sparsity_threshold, layout) + weight_info = process_params(func, params, blocksize, sparsity_threshold, layout, kernel_size) new_func = _run_opt_pass( func, - relay.transform.Conv2dToSparse(weight_info.weight_name, weight_info.weight_shape, layout), + relay.transform.Conv2dToSparse(weight_info.weight_name, weight_info.weight_shape, layout, kernel_size), ) return new_func, params diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py index 96cef8bc3588..f9d63c9ec4c8 100644 --- a/python/tvm/relay/op/nn/_nn.py +++ 
b/python/tvm/relay/op/nn/_nn.py @@ -198,7 +198,7 @@ def compute_sparse_transpose(attrs, inputs, out_type): @reg.register_compute("nn.sparse_conv2d") def compute_sparse_conv2d(attrs, inputs, out_type): """Compute definition of sparse_conv2d""" - return [topi.nn.sparse_conv2d(inputs[0], inputs[1], inputs[2], inputs[3], attrs["layout"])] + return [topi.nn.sparse_conv2d(inputs[0], inputs[1], inputs[2], inputs[3], attrs["layout"], attrs['kernel_size'])] reg.register_strategy("nn.sparse_conv2d", strategy.sparse_conv2d_strategy) diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py index 6294e7acea15..174a0c511633 100644 --- a/python/tvm/relay/transform/transform.py +++ b/python/tvm/relay/transform/transform.py @@ -1093,7 +1093,7 @@ def DenseToSparse(weight_name, weight_shape): return _ffi_api.DenseToSparse(weight_name, weight_shape) -def Conv2dToSparse(weight_name, weight_shape, layout): +def Conv2dToSparse(weight_name, weight_shape, layout, kernel_size): """ Rewrite qualified ```nn.conv2d operation``` to ```nn.sparse_conv2d``` @@ -1113,7 +1113,7 @@ def Conv2dToSparse(weight_name, weight_shape, layout): ret : tvm.transform.Pass The registered DenseToSparse pass. """ - return _ffi_api.Conv2dToSparse(weight_name, weight_shape, layout) + return _ffi_api.Conv2dToSparse(weight_name, weight_shape, layout, kernel_size) def SimplifyFCTranspose(target_weight_name): diff --git a/python/tvm/topi/nn/sparse.py b/python/tvm/topi/nn/sparse.py index 948847e60d92..b294e09b04cf 100644 --- a/python/tvm/topi/nn/sparse.py +++ b/python/tvm/topi/nn/sparse.py @@ -566,7 +566,7 @@ def _compute_block(i, nb_j, j, h, w): # pylint: disable=C0103 ) -def sparse_conv2d(dense_data, sparse_data, sparse_indices, sparse_indptr, layout="NHWC"): +def sparse_conv2d(dense_data, sparse_data, sparse_indices, sparse_indptr, layout="NHWC", kernel_size=1): """ Computes sparse-conv2d(1*1) of ``data`` and ``(weight_data, weight_indices, weight_indptr)`` diff --git a/src/relay/op/nn/sparse.cc b/src/relay/op/nn/sparse.cc index 32b0811b48ac..99fb4699123d 100644 --- a/src/relay/op/nn/sparse.cc +++ b/src/relay/op/nn/sparse.cc @@ -274,10 +274,11 @@ bool SparseConv2dRel(const Array& types, int num_inputs, const Attrs& attr } Expr MakeSparseConv2d(Expr data, Expr weight_data, Expr weight_indices, Expr weight_indptr, - std::string layout) { + std::string layout, int kernel_size) { static const Op& op = Op::Get("nn.sparse_conv2d"); auto attrs = make_object(); attrs->layout = std::move(layout); + attrs->kernel_size = kernel_size; return Call(op, {data, weight_data, weight_indices, weight_indptr}, Attrs(attrs), {}); } diff --git a/src/relay/transforms/convert_sparse_conv2d.cc b/src/relay/transforms/convert_sparse_conv2d.cc index 6e4c03b0fcbc..566b92518e90 100644 --- a/src/relay/transforms/convert_sparse_conv2d.cc +++ b/src/relay/transforms/convert_sparse_conv2d.cc @@ -73,10 +73,12 @@ TVM_REGISTER_GLOBAL("relay.analysis.search_conv2d_op_weight").set_body_typed(Sea class Conv2dToSparseConv2dMutator : public ExprRewriter { public: Conv2dToSparseConv2dMutator(const Array& weight_name, - const Array>& weight_shape, const String& layout) + const Array>& weight_shape, + const String& layout, int kernel_size) : conv2d_op_(Op::Get("nn.conv2d")), sparse_conv2d_op_(Op::Get("nn.sparse_conv2d")) { ICHECK_EQ(weight_name.size(), weight_shape.size()); layout_ = layout; + kernel_size_ = kernel_size; for (size_t i = 0; i < weight_name.size(); ++i) { ICHECK(weight_name[i]->IsInstance()); std::string k = weight_name[i].as()->data; @@ 
-112,6 +114,7 @@ class Conv2dToSparseConv2dMutator : public ExprRewriter { Var weight_indptr(prefix + ".indptr", ws_indptr_type); auto attrs = make_object(); attrs->layout = std::move(layout_); + attrs->kernel_size = kernel_size_; return Call(sparse_conv2d_op_, {data, weight_data, weight_indices, weight_indptr}, Attrs(attrs)); } @@ -126,22 +129,23 @@ class Conv2dToSparseConv2dMutator : public ExprRewriter { const Op& sparse_conv2d_op_; std::unordered_map> target_weights_; String layout_; + int kernel_size_; }; // class Conv2dToSparseConv2dAlter Expr Conv2dToSparse(const Expr& e, const Array& weight_name, - const Array>& weight_shape, const String& layout) { - auto rewriter = Conv2dToSparseConv2dMutator(weight_name, weight_shape, layout); + const Array>& weight_shape, const String& layout, int kernel_size) { + auto rewriter = Conv2dToSparseConv2dMutator(weight_name, weight_shape, layout, kernel_size); return PostOrderRewrite(e, &rewriter); } namespace transform { Pass Conv2dToSparse(const Array& weight_name, const Array>& weight_shape, - const String& layout) { + const String& layout, int kernel_size) { runtime::TypedPackedFunc pass_func = [=](Function f, IRModule m, PassContext pc) { // Remove FreeVar warnings - auto f0 = Downcast(Conv2dToSparse(f, weight_name, weight_shape, layout)); + auto f0 = Downcast(Conv2dToSparse(f, weight_name, weight_shape, layout, kernel_size)); Array sparse_params = FreeVars(f0); auto f1 = Function(sparse_params, f0->body, f0->ret_type, f0->type_params, f0->attrs); Array params = FreeVars(f1); From b973456fb6b9d8982c4e2aea823dfa52c1e007a4 Mon Sep 17 00:00:00 2001 From: Tantalus Date: Fri, 2 Jul 2021 13:55:09 +0800 Subject: [PATCH 03/21] [relay] add strategy for spconv2d_3x3 nhwc --- python/tvm/relay/op/strategy/x86.py | 19 +++++++++++++++++++ python/tvm/topi/nn/sparse.py | 17 +++++++++-------- python/tvm/topi/x86/sparse.py | 2 +- 3 files changed, 29 insertions(+), 9 deletions(-) diff --git a/python/tvm/relay/op/strategy/x86.py b/python/tvm/relay/op/strategy/x86.py index a6e141f2753b..ddc404aa2787 100644 --- a/python/tvm/relay/op/strategy/x86.py +++ b/python/tvm/relay/op/strategy/x86.py @@ -565,6 +565,25 @@ def sparse_dense_strategy_cpu(attrs, inputs, out_type, target): return strategy +@sparse_conv2d_strategy.register("cpu") +def sparse_conv2d_strategy_cpu(attrs, inputs, out_type, target): + """sparse conv2d x86 strategy""" + strategy = _op.OpStrategy() + if attrs["kernel_size"] == 1: + strategy.add_implementation( + wrap_compute_sparse_conv2d(topi.nn.sparse_conv2d), + wrap_topi_schedule(topi.generic.schedule_sparse_conv2d), + name="sparse_conv2d.generic", + ) + elif attrs["kernel_size"] == 3 and attrs["layout"] == "NHWC": + strategy.add_implementation( + wrap_compute_sparse_conv2d(topi.x86.spconv2d_3x3_nhwc), + wrap_topi_schedule(topi.x86.schedule_spconv2d_3x3_nhwc), + name="conv3x3_spNHWC.x86", + ) + return strategy + + @roi_align_strategy.register("cpu") def roi_align_strategy_cpu(attrs, inputs, out_type, target): """roi_align x86 strategy""" diff --git a/python/tvm/topi/nn/sparse.py b/python/tvm/topi/nn/sparse.py index b294e09b04cf..c6b50fc3c8c5 100644 --- a/python/tvm/topi/nn/sparse.py +++ b/python/tvm/topi/nn/sparse.py @@ -598,14 +598,15 @@ def sparse_conv2d(dense_data, sparse_data, sparse_indices, sparse_indptr, layout 4-D with shape [M, H, W, N] (layout=NHWC) 4-D with shape [M, N, H ,W] (layout=NCHW) """ - if layout == "NHWC": - return _sparse_conv2d_bsr_compute_nhwc( - dense_data, sparse_data, sparse_indices, sparse_indptr - ) - elif layout == "NCHW": - 
return _sparse_conv2d_bsr_compute_nchw( - dense_data, sparse_data, sparse_indices, sparse_indptr - ) + if kernel_size == 1: + if layout == "NHWC": + return _sparse_conv2d_bsr_compute_nhwc( + dense_data, sparse_data, sparse_indices, sparse_indptr + ) + elif layout == "NCHW": + return _sparse_conv2d_bsr_compute_nchw( + dense_data, sparse_data, sparse_indices, sparse_indptr + ) else: raise ValueError("Unsupport Layout %s" % layout) diff --git a/python/tvm/topi/x86/sparse.py b/python/tvm/topi/x86/sparse.py index 4f36ce23dbfd..6f7e15c22441 100644 --- a/python/tvm/topi/x86/sparse.py +++ b/python/tvm/topi/x86/sparse.py @@ -65,7 +65,7 @@ def _callback(op): @autotvm.register_topi_compute('conv3x3_spNHWC.x86') -def spconv2d_3x3_nhwc(cfg, Data, Wdat, Wind, Wptr): +def spconv2d_3x3_nhwc(cfg, Data, Wdat, Wind, Wptr, layout="NHWC"): '''# My SpConv2d_3x3_gemm Data: N,H,W,C -> NHW,33C From ae4a200f84847ac80d8cacae6b9b49989c07453f Mon Sep 17 00:00:00 2001 From: Tantalus Date: Wed, 7 Jul 2021 12:33:34 +0800 Subject: [PATCH 04/21] [relay] pass to convert spconv2d with const args --- python/tvm/autotvm/measure/measure_methods.py | 15 +- src/relay/transforms/convert_sparse_conv2d.cc | 169 ++++++++++++++++++ 2 files changed, 177 insertions(+), 7 deletions(-) diff --git a/python/tvm/autotvm/measure/measure_methods.py b/python/tvm/autotvm/measure/measure_methods.py index db4ff26857bd..eab6822b63b8 100644 --- a/python/tvm/autotvm/measure/measure_methods.py +++ b/python/tvm/autotvm/measure/measure_methods.py @@ -254,13 +254,14 @@ def ref_input(self): @ref_input.setter def ref_input(self, val): - warnings.warn( - "You are specifying fixed input for tuning the operator. " - "Be sure your input always fits the operator. Some " - "operators may conduct layout transformation during tuning, " - "thus can lead to unexpected behaviors. ", - RuntimeWarning, - ) + if val is not None: + warnings.warn( + "You are specifying fixed input for tuning the operator. " + "Be sure your input always fits the operator. Some " + "operators may conduct layout transformation during tuning, " + "thus can lead to unexpected behaviors. 
", + RuntimeWarning, + ) self._ref_input = val def set_task(self, task): diff --git a/src/relay/transforms/convert_sparse_conv2d.cc b/src/relay/transforms/convert_sparse_conv2d.cc index 566b92518e90..e20d3a4f8960 100644 --- a/src/relay/transforms/convert_sparse_conv2d.cc +++ b/src/relay/transforms/convert_sparse_conv2d.cc @@ -138,6 +138,163 @@ Expr Conv2dToSparse(const Expr& e, const Array& weight_name, return PostOrderRewrite(e, &rewriter); } + +template +auto unpack_to_tuple_internal(elemTy *arr, std::index_sequence) { + return std::make_tuple(arr[Is]...); +} + +template +auto unpack_to_tuple(elemTy *arr) { + return unpack_to_tuple_internal(arr, std::make_index_sequence{}); +} + +struct Range { + size_t dim; + Range(size_t d): dim(d) {} + + struct iterpoint { + size_t val, lim; + iterpoint(size_t v1, size_t v2): val(v1), lim(v2) {} + + size_t operator*() const { + return val; + } + + iterpoint operator/(const iterpoint &rhs) const { + return iterpoint(val * rhs.lim + rhs.val, lim * rhs.lim); + } + }; + + struct iterator { + size_t val, lim; + iterator(size_t v1, size_t v2): val(v1), lim(v2) {} + + bool operator!=(const iterator &rhs) const { + return val != rhs.val; + } + + void operator++() { + ++val; + } + + iterpoint operator*() const { + return iterpoint(val, lim); + } + }; + + iterator begin() { + return iterator(0, dim); + } + + iterator end() { + return iterator(dim, dim); + } +}; + + +// Mutate ```nn.conv2d``` to ```nn.sparse_conv2d``` +class Conv2dToSparseConv2dMutator2 : public ExprRewriter { + public: + Conv2dToSparseConv2dMutator2(const String& layout, int kernel_size, + int blockH, int blockW, double sparse_thresh) + : sparse_conv2d_op_(Op::Get("nn.sparse_conv2d")), dev_cpu0_{DLDeviceType::kDLCPU, 0}, + layout_(layout), kernel_size_(kernel_size), + blockH_(blockH), blockW_(blockW), sparse_thresh_(sparse_thresh) { } + + Expr Rewrite_(const CallNode* pre, const Expr& post) override { + // check op type & attrs + const auto pre_attrs = pre->attrs.as(); + if (!pre_attrs + || pre_attrs->data_layout != layout_ + || pre_attrs->strides[0].as()->value != 1 + || pre_attrs->kernel_size[0].as()->value != kernel_size_) + return post; + // check constant weight + const auto pre_weight_node = pre->args[1].as(); + if (!pre_weight_node) + return post; + + // check weight dtype & shape + auto &&pre_weight = pre_weight_node->data; + auto dtype = pre_weight.DataType(), itype = runtime::DataType::Int(32); + ICHECK(dtype.code() == DataType::kFloat && dtype.bits() == 32); // float32 only + auto pre_weight_shape = unpack_to_tuple<4>(pre_weight.Shape().data()); + int O, I, H, W; + if (layout_ == "NCHW") { + std::tie(O, I, H, W) = pre_weight_shape; + } else { // NHWC + std::tie(H, W, I, O) = pre_weight_shape; + } + int CO = O, CI = H * W * I; + + // copy to vector + std::vector pre_weight_data(CO * CI); + pre_weight.CopyToBytes(pre_weight_data.data(), pre_weight_data.size() * sizeof(float)); + if (layout_ == "NHWC") { + std::vector tmp(pre_weight_data.size()); + for (auto i: Range(CO)) + for (auto j: Range(CI)) + tmp[*(i / j)] = pre_weight_data[*(j / i)]; + std::swap(tmp, pre_weight_data); + } + // convert to BSR + std::vector wdata, block(blockH_ * blockW_); + std::vector windices, windptr; + for (auto bh: Range(CO / blockH_)) { + windptr.push_back(windices.size()); + for (auto bw: Range(CI / blockW_)) { + int cntnnz = 0; + for (auto i: Range(blockH_)) + for (auto j: Range(blockW_)) { + auto tmp = pre_weight_data[*(bh / i / bw / j)]; + if (tmp) cntnnz++; + block[*(i / j)] = tmp; + } + if (cntnnz) { 
+ wdata.insert(wdata.end(), block.begin(), block.end()); + windices.push_back(*bw); + } + } + } + windptr.push_back(windices.size()); + double sprate = 1 - 1.0 * wdata.size() / pre_weight_data.size(); + if (sprate < sparse_thresh_) return post; + + // constrct return data + int nnz = windices.size(); + auto weight_data = runtime::NDArray::Empty({nnz, blockH_, blockW_}, dtype, dev_cpu0_); + auto weight_indices = runtime::NDArray::Empty({nnz}, itype, dev_cpu0_); + auto weight_indptr = runtime::NDArray::Empty({CO / blockH_ + 1}, itype, dev_cpu0_); + weight_data.CopyFromBytes(wdata.data(), wdata.size() * sizeof(float)); + weight_indices.CopyFromBytes(windices.data(), windices.size() * sizeof(int32_t)); + weight_indptr.CopyFromBytes(windptr.data(), windptr.size() * sizeof(int32_t)); + + // construct return call + auto args = runtime::Array { + post.as()->args[0], + Constant(weight_data), Constant(weight_indices), Constant(weight_indptr) + }; + auto attrs = make_object(); + attrs->layout = layout_; + attrs->kernel_size = kernel_size_; + return Call(sparse_conv2d_op_, args, Attrs(attrs)); + } + +private: + const Op &sparse_conv2d_op_; + DLDevice dev_cpu0_; + String layout_; + int kernel_size_, blockH_, blockW_; + double sparse_thresh_; +}; // class Conv2dToSparseConv2dMutator2 + +Expr Conv2dToSparse2(const Expr& e, const String& layout, int kernel_size, int blockH, int blockW, double sparse_thresh) { + auto rewriter = Conv2dToSparseConv2dMutator2(layout, kernel_size, blockH, blockW, sparse_thresh); + return PostOrderRewrite(e, &rewriter); +} + + namespace transform { Pass Conv2dToSparse(const Array& weight_name, const Array>& weight_shape, @@ -159,6 +316,18 @@ Pass Conv2dToSparse(const Array& weight_name, const Array pass_func = + [=](Function f, IRModule m, PassContext pc) { + auto f0 = Downcast(Conv2dToSparse2(f, layout, kernel_size, blockH, blockW, sparse_thresh)); + return f0; + }; + return CreateFunctionPass(pass_func, 5, "Conv2dToSparse2", {"DeadCodeElimination"}); +} + +TVM_REGISTER_GLOBAL("relay._transform.Conv2dToSparse2").set_body_typed(Conv2dToSparse2); + } // namespace transform } // namespace relay From ef222a9fded246630f7e0ebdf336c632c00e9705 Mon Sep 17 00:00:00 2001 From: Tantalus Date: Tue, 20 Jul 2021 14:21:23 +0800 Subject: [PATCH 05/21] [topi] add sparse_conv2d 3x3 nchw --- python/tvm/topi/x86/sparse.py | 62 +++++++++++++++++++++++++++++++++++ 1 file changed, 62 insertions(+) diff --git a/python/tvm/topi/x86/sparse.py b/python/tvm/topi/x86/sparse.py index 6f7e15c22441..c4f9a3fd784b 100644 --- a/python/tvm/topi/x86/sparse.py +++ b/python/tvm/topi/x86/sparse.py @@ -159,3 +159,65 @@ def _callback(op): traverse_inline(s, outs[0].op, _callback) return s + + +@autotvm.register_topi_compute('conv3x3_spNCHW.x86') +def spconv2d_3x3_nchw(cfg, Data, Wdat, Wind, Wptr): + N, CI, H, W = [i.value for i in Data.shape] + NNZ, VL, bsrC = [i.value for i in Wdat.shape] + CO = (Wptr.shape[0].value - 1) * VL + assert bsrC == 1 + + cfg.add_flop(N*H*W * (NNZ * VL * bsrC * 2 - CO)) + cfg.define_split("tile_hw", H*W, num_outputs=3) + cfg.define_split("tile_ckk", CI*9, num_outputs=3) + + @partial(te.compute, (N, CI*3*3, H*W), name='im2col') + def Im2Col(n, ckk, hw): + jh, jw = hw // W, hw % W + ic, kh, kw = ckk // 9, ckk // 3 % 3, ckk % 3 + ih, iw = jh + kh - 1, jw + kw - 1 + return tir.if_then_else( + tir.all(0 <= ih, ih < H, 0 <= iw, iw < W), + Data[n, ic, ih, iw], 0) + + @partial(te.compute, (N, CO // VL, VL, bsrC, H*W), name='CC', tag='conv3x3_spNCHW') + def CC(n, fo, fi, k, hw): + row_start, row_end 
= Wptr[fo], Wptr[fo+1] + elem_idx = te.reduce_axis((0, row_end - row_start), name='elem_idx') + elem = row_start + elem_idx + return te.sum(Im2Col[n, Wind[elem] * bsrC + k, hw] * Wdat[elem, fi, k], + axis=elem_idx) + + return reshape(CC, [N, CO, H, W]) + + +@autotvm.register_topi_schedule('conv3x3_spNCHW.x86') +def schedule_spconv2d_3x3_nchw(cfg, outs): + outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs + s = te.create_schedule([x.op for x in outs]) + + def _callback(op): + if op.tag == 'conv3x3_spNCHW': + CC = op + Wptr, Wind, im2col, Wdat = op.input_tensors + Data, = im2col.op.input_tensors + + n, fo, fi, bc, hw = s[CC].op.axis + kk, = s[CC].op.reduce_axis + hw1, hw2, hw3 = cfg["tile_hw"].apply(s, CC, hw) + s[CC].reorder(n, hw1, fo, hw2, kk, fi, bc, hw3) + s[CC].unroll(fi) + s[CC].unroll(bc) + s[CC].vectorize(hw3) + + s[im2col].compute_at(s[CC], hw1) + n, ckk, hw = s[im2col].op.axis + ckk1, ckk2, ckk3 = cfg["tile_ckk"].apply(s, im2col, ckk) + hw2, hw3 = s[im2col].split(hw, factor=cfg["tile_hw"].size[-1]) + s[im2col].reorder(n, ckk1, ckk2, hw2, ckk3, hw3) + s[im2col].unroll(ckk3) + s[im2col].vectorize(hw3) + + traverse_inline(s, outs[0].op, _callback) + return s \ No newline at end of file From bc28f46d9ffc609bf339166eaced0e136d86a952 Mon Sep 17 00:00:00 2001 From: Tantalus Date: Tue, 20 Jul 2021 14:41:08 +0800 Subject: [PATCH 06/21] [relay] add strategy for sparse_conv2d 3x3 nchw --- python/tvm/relay/op/nn/_nn.py | 2 +- python/tvm/relay/op/strategy/x86.py | 18 ++++++---- python/tvm/topi/x86/sparse.py | 54 ++++++++++------------------- 3 files changed, 32 insertions(+), 42 deletions(-) diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py index f9d63c9ec4c8..c6a6c6ca8221 100644 --- a/python/tvm/relay/op/nn/_nn.py +++ b/python/tvm/relay/op/nn/_nn.py @@ -198,7 +198,7 @@ def compute_sparse_transpose(attrs, inputs, out_type): @reg.register_compute("nn.sparse_conv2d") def compute_sparse_conv2d(attrs, inputs, out_type): """Compute definition of sparse_conv2d""" - return [topi.nn.sparse_conv2d(inputs[0], inputs[1], inputs[2], inputs[3], attrs["layout"], attrs['kernel_size'])] + return [topi.nn.sparse_conv2d(inputs[0], inputs[1], inputs[2], inputs[3], attrs["layout"], attrs["kernel_size"])] reg.register_strategy("nn.sparse_conv2d", strategy.sparse_conv2d_strategy) diff --git a/python/tvm/relay/op/strategy/x86.py b/python/tvm/relay/op/strategy/x86.py index ddc404aa2787..710d545f69f5 100644 --- a/python/tvm/relay/op/strategy/x86.py +++ b/python/tvm/relay/op/strategy/x86.py @@ -575,12 +575,18 @@ def sparse_conv2d_strategy_cpu(attrs, inputs, out_type, target): wrap_topi_schedule(topi.generic.schedule_sparse_conv2d), name="sparse_conv2d.generic", ) - elif attrs["kernel_size"] == 3 and attrs["layout"] == "NHWC": - strategy.add_implementation( - wrap_compute_sparse_conv2d(topi.x86.spconv2d_3x3_nhwc), - wrap_topi_schedule(topi.x86.schedule_spconv2d_3x3_nhwc), - name="conv3x3_spNHWC.x86", - ) + elif attrs["kernel_size"] == 3: + if attrs["layout"] == "NHWC": + strategy.add_implementation( + wrap_compute_sparse_conv2d(topi.x86.spconv2d_3x3_nhwc), + wrap_topi_schedule(topi.x86.schedule_spconv2d_3x3_nhwc), + name="conv3x3_spNHWC.x86", + ) + elif attrs["layout"] == "NCHW": + strategy.add_implementation( + wrap_compute_sparse_conv2d(topi.x86.spconv2d_3x3_nchw), + wrap_topi_schedule(topi.x86.schedule_spconv2d_3x3_nchw), + ) return strategy diff --git a/python/tvm/topi/x86/sparse.py b/python/tvm/topi/x86/sparse.py index c4f9a3fd784b..be24d7ef5d96 100644 --- 
a/python/tvm/topi/x86/sparse.py +++ b/python/tvm/topi/x86/sparse.py @@ -64,39 +64,23 @@ def _callback(op): return s -@autotvm.register_topi_compute('conv3x3_spNHWC.x86') +@autotvm.register_topi_compute("conv3x3_spNHWC.x86") def spconv2d_3x3_nhwc(cfg, Data, Wdat, Wind, Wptr, layout="NHWC"): - '''# My SpConv2d_3x3_gemm - - Data: N,H,W,C -> NHW,33C - Weight: F,3,3,C -> F,33C - - yt, xt, yo => - yi, k9, ci:vec => - @im2col = {yt, yo, yi}/y, {k9, ci}/k - xo => - x1:1, ko:dyn(xr), yi:unroll, xi:vec, ki:unroll => - @CC = {yt, yo, yi}/y, {xt, xo, x1}/xr, xi, ki // ko - yi:unroll, xi:vec, ki:unroll => - @C = {yt, yo, yi}/y, {xt, xo, xi}/x // ki - ''' N, H, W, CI = [i.value for i in Data.shape] nElems, bsrR, bsrC = [i.value for i in Wdat.shape] CO = (Wptr.shape[0].value - 1) * bsrR Y, X, K = N*H*W, CO, 9*CI - # cfg = autotvm.get_config() cfg.define_split("tile_y", Y, num_outputs=3) cfg.define_split("tile_x", X // bsrR, num_outputs=2) cfg.add_flop(Y * (nElems * bsrC * bsrR * 2 - X)) - #cfg.define_split("tile_k", K, num_outputs=2) if cfg.is_fallback: - cfg['tile_y'] = autotvm.task.space.SplitEntity([-1, 160, 8]) - cfg['tile_x'] = autotvm.task.space.SplitEntity([-1, 4]) + cfg["tile_y"] = autotvm.task.space.SplitEntity([-1, 160, 8]) + cfg["tile_x"] = autotvm.task.space.SplitEntity([-1, 4]) idxsplit = lambda x,y: reduce(lambda a,b: a[:-1]+[a[-1]%b,a[-1]//b], y, [x]) - @partial(te.compute, (Y, K), name='Im2Col') + @partial(te.compute, (Y, K), name="Im2Col") def Im2Col(row, col): jw, jh, jn = idxsplit(row, [W, H]) jc, kw, kh = idxsplit(col, [CI, 3]) @@ -105,27 +89,27 @@ def Im2Col(row, col): tir.all(0 <= ih, ih < H, 0 <= iw, iw < W), Data[jn, ih, iw, jc], 0) - @partial(te.compute, (Y, X // bsrR, bsrR, bsrC), name='CC') + @partial(te.compute, (Y, X // bsrR, bsrR, bsrC), name="CC") def CC(drow, wrow, brow, bcol): row_start, row_end = Wptr[wrow], Wptr[wrow+1] - elem_idx = te.reduce_axis((0, row_end - row_start), name='elem_idx') + elem_idx = te.reduce_axis((0, row_end - row_start), name="elem_idx") elem = row_start + elem_idx return te.sum(Im2Col[drow, Wind[elem]*bsrC + bcol] * Wdat[elem, brow, bcol], axis=elem_idx) - k = te.reduce_axis((0, bsrC), name='k') + k = te.reduce_axis((0, bsrC), name="k") C = te.compute((Y, X), lambda y, x: te.sum(CC[y, x // bsrR, x % bsrR, k], axis=k), - name='C', tag='conv3x3_spNHWC') + name="C", tag="conv3x3_spNHWC") return reshape(C, (N, H, W, CO)) -@autotvm.register_topi_schedule('conv3x3_spNHWC.x86') +@autotvm.register_topi_schedule("conv3x3_spNHWC.x86") def schedule_spconv2d_3x3_nhwc(cfg, outs): outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs s = te.create_schedule([x.op for x in outs]) def _callback(op): - if op.tag == 'conv3x3_spNHWC': + if op.tag == "conv3x3_spNHWC": C = op CC, = op.input_tensors Wptr, Wind, Im2Col, Wdat = CC.op.input_tensors @@ -134,9 +118,9 @@ def _callback(op): CI = Data.shape[-1].value y, x = s[C].op.axis - yt, yo, yi = cfg['tile_y'].apply(s, C, y) + yt, yo, yi = cfg["tile_y"].apply(s, C, y) xo, xi = s[C].split(x, factor=bsrR) - xt, xo = cfg['tile_x'].apply(s, C, xo) + xt, xo = cfg["tile_x"].apply(s, C, xo) (k,) = s[C].op.reduce_axis s[C].reorder(yt, xt, yo, xo, yi, xi, k) s[C].unroll(k) @@ -161,8 +145,8 @@ def _callback(op): return s -@autotvm.register_topi_compute('conv3x3_spNCHW.x86') -def spconv2d_3x3_nchw(cfg, Data, Wdat, Wind, Wptr): +@autotvm.register_topi_compute("conv3x3_spNCHW.x86") +def spconv2d_3x3_nchw(cfg, Data, Wdat, Wind, Wptr, layout="NCHW"): N, CI, H, W = [i.value for i in Data.shape] NNZ, VL, bsrC = [i.value for i in 
Wdat.shape] CO = (Wptr.shape[0].value - 1) * VL @@ -172,7 +156,7 @@ def spconv2d_3x3_nchw(cfg, Data, Wdat, Wind, Wptr): cfg.define_split("tile_hw", H*W, num_outputs=3) cfg.define_split("tile_ckk", CI*9, num_outputs=3) - @partial(te.compute, (N, CI*3*3, H*W), name='im2col') + @partial(te.compute, (N, CI*3*3, H*W), name="im2col") def Im2Col(n, ckk, hw): jh, jw = hw // W, hw % W ic, kh, kw = ckk // 9, ckk // 3 % 3, ckk % 3 @@ -181,10 +165,10 @@ def Im2Col(n, ckk, hw): tir.all(0 <= ih, ih < H, 0 <= iw, iw < W), Data[n, ic, ih, iw], 0) - @partial(te.compute, (N, CO // VL, VL, bsrC, H*W), name='CC', tag='conv3x3_spNCHW') + @partial(te.compute, (N, CO // VL, VL, bsrC, H*W), name="CC", tag="conv3x3_spNCHW") def CC(n, fo, fi, k, hw): row_start, row_end = Wptr[fo], Wptr[fo+1] - elem_idx = te.reduce_axis((0, row_end - row_start), name='elem_idx') + elem_idx = te.reduce_axis((0, row_end - row_start), name="elem_idx") elem = row_start + elem_idx return te.sum(Im2Col[n, Wind[elem] * bsrC + k, hw] * Wdat[elem, fi, k], axis=elem_idx) @@ -192,13 +176,13 @@ def CC(n, fo, fi, k, hw): return reshape(CC, [N, CO, H, W]) -@autotvm.register_topi_schedule('conv3x3_spNCHW.x86') +@autotvm.register_topi_schedule("conv3x3_spNCHW.x86") def schedule_spconv2d_3x3_nchw(cfg, outs): outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs s = te.create_schedule([x.op for x in outs]) def _callback(op): - if op.tag == 'conv3x3_spNCHW': + if op.tag == "conv3x3_spNCHW": CC = op Wptr, Wind, im2col, Wdat = op.input_tensors Data, = im2col.op.input_tensors From 5a0ac03f8545d7d88b3fc6329535fef1dfe5d7f2 Mon Sep 17 00:00:00 2001 From: Tantalus Date: Thu, 5 Aug 2021 16:56:18 +0800 Subject: [PATCH 07/21] [relay] convert sparse conv2d pass fixes --- python/tvm/relay/analysis/sparse_conv2d.py | 7 +++- .../relay/data_dep_optimization/bsr_conv2d.py | 2 +- src/relay/transforms/convert_sparse_conv2d.cc | 37 ++++++++++++------- 3 files changed, 29 insertions(+), 17 deletions(-) diff --git a/python/tvm/relay/analysis/sparse_conv2d.py b/python/tvm/relay/analysis/sparse_conv2d.py index 07dc17611559..0bb865db536d 100644 --- a/python/tvm/relay/analysis/sparse_conv2d.py +++ b/python/tvm/relay/analysis/sparse_conv2d.py @@ -112,8 +112,11 @@ def process_params(expr, params, block_size, sparsity_threshold, layout, kernel_ ) else: sparse_weight_data = sparse_weight.data - elif weight_kernel[0] == kernel_size == 3 and layout == "NHWC": - w_np = w_np.reshape((-1, w_np.shape[-1])).T + elif weight_kernel[0] == kernel_size == 3: + if layout == "NHWC": # HWIO + w_np = w_np.reshape((-1, w_np.shape[-1])).T + elif layout == "NCHW": # OIHW + w_np = w_np.reshape((w_np.shape[0], -1)) sparse_weight = sp.bsr_matrix(w_np, blocksize=block_size) if 1 - (sparse_weight.nnz / w_np.size) < sparsity_threshold: continue diff --git a/python/tvm/relay/data_dep_optimization/bsr_conv2d.py b/python/tvm/relay/data_dep_optimization/bsr_conv2d.py index 1bb3a687d9de..2133af5711b4 100644 --- a/python/tvm/relay/data_dep_optimization/bsr_conv2d.py +++ b/python/tvm/relay/data_dep_optimization/bsr_conv2d.py @@ -24,7 +24,7 @@ def convert(func, params, blocksize, sparsity_threshold, layout="NHWC", kernel_size=1): - """Convert a dense func and according parameters to block sparse + """Convert a conv2d func and according parameters to block sparse Parameters ---------- diff --git a/src/relay/transforms/convert_sparse_conv2d.cc b/src/relay/transforms/convert_sparse_conv2d.cc index e20d3a4f8960..c21d8ca81cb7 100644 --- a/src/relay/transforms/convert_sparse_conv2d.cc +++ 
b/src/relay/transforms/convert_sparse_conv2d.cc @@ -133,7 +133,8 @@ class Conv2dToSparseConv2dMutator : public ExprRewriter { }; // class Conv2dToSparseConv2dAlter Expr Conv2dToSparse(const Expr& e, const Array& weight_name, - const Array>& weight_shape, const String& layout, int kernel_size) { + const Array>& weight_shape, + const String& layout, int kernel_size) { auto rewriter = Conv2dToSparseConv2dMutator(weight_name, weight_shape, layout, kernel_size); return PostOrderRewrite(e, &rewriter); } @@ -151,7 +152,7 @@ auto unpack_to_tuple(elemTy *arr) { struct Range { size_t dim; - Range(size_t d): dim(d) {} + explicit Range(size_t d): dim(d) {} struct iterpoint { size_t val, lim; @@ -205,7 +206,7 @@ class Conv2dToSparseConv2dMutator2 : public ExprRewriter { Expr Rewrite_(const CallNode* pre, const Expr& post) override { // check op type & attrs const auto pre_attrs = pre->attrs.as(); - if (!pre_attrs + if (!pre_attrs || pre_attrs->data_layout != layout_ || pre_attrs->strides[0].as()->value != 1 || pre_attrs->kernel_size[0].as()->value != kernel_size_) @@ -233,20 +234,20 @@ class Conv2dToSparseConv2dMutator2 : public ExprRewriter { pre_weight.CopyToBytes(pre_weight_data.data(), pre_weight_data.size() * sizeof(float)); if (layout_ == "NHWC") { std::vector tmp(pre_weight_data.size()); - for (auto i: Range(CO)) - for (auto j: Range(CI)) + for (auto i : Range(CO)) + for (auto j : Range(CI)) tmp[*(i / j)] = pre_weight_data[*(j / i)]; std::swap(tmp, pre_weight_data); } // convert to BSR std::vector wdata, block(blockH_ * blockW_); std::vector windices, windptr; - for (auto bh: Range(CO / blockH_)) { + for (auto bh : Range(CO / blockH_)) { windptr.push_back(windices.size()); - for (auto bw: Range(CI / blockW_)) { + for (auto bw : Range(CI / blockW_)) { int cntnnz = 0; - for (auto i: Range(blockH_)) - for (auto j: Range(blockW_)) { + for (auto i : Range(blockH_)) + for (auto j : Range(blockW_)) { auto tmp = pre_weight_data[*(bh / i / bw / j)]; if (tmp) cntnnz++; block[*(i / j)] = tmp; @@ -281,7 +282,7 @@ class Conv2dToSparseConv2dMutator2 : public ExprRewriter { return Call(sparse_conv2d_op_, args, Attrs(attrs)); } -private: + private: const Op &sparse_conv2d_op_; DLDevice dev_cpu0_; String layout_; @@ -289,7 +290,8 @@ class Conv2dToSparseConv2dMutator2 : public ExprRewriter { double sparse_thresh_; }; // class Conv2dToSparseConv2dMutator2 -Expr Conv2dToSparse2(const Expr& e, const String& layout, int kernel_size, int blockH, int blockW, double sparse_thresh) { +Expr Conv2dToSparse2(const Expr& e, const String& layout, int kernel_size, + int blockH, int blockW, double sparse_thresh) { auto rewriter = Conv2dToSparseConv2dMutator2(layout, kernel_size, blockH, blockW, sparse_thresh); return PostOrderRewrite(e, &rewriter); } @@ -297,12 +299,15 @@ Expr Conv2dToSparse2(const Expr& e, const String& layout, int kernel_size, int b namespace transform { +// Convert a model with seperate weight info (already sparsified). 
Pass Conv2dToSparse(const Array& weight_name, const Array>& weight_shape, const String& layout, int kernel_size) { runtime::TypedPackedFunc pass_func = [=](Function f, IRModule m, PassContext pc) { // Remove FreeVar warnings - auto f0 = Downcast(Conv2dToSparse(f, weight_name, weight_shape, layout, kernel_size)); + auto f0 = Downcast( + Conv2dToSparse(f, weight_name, weight_shape, layout, kernel_size) + ); Array sparse_params = FreeVars(f0); auto f1 = Function(sparse_params, f0->body, f0->ret_type, f0->type_params, f0->attrs); Array params = FreeVars(f1); @@ -317,10 +322,14 @@ Pass Conv2dToSparse(const Array& weight_name, const Array pass_func = [=](Function f, IRModule m, PassContext pc) { - auto f0 = Downcast(Conv2dToSparse2(f, layout, kernel_size, blockH, blockW, sparse_thresh)); + auto f0 = Downcast( + Conv2dToSparse2(f, layout, kernel_size, blockH, blockW, sparse_thresh) + ); return f0; }; return CreateFunctionPass(pass_func, 5, "Conv2dToSparse2", {"DeadCodeElimination"}); From f7574196fb4f0adb2f57868aad2cfa31a261cd70 Mon Sep 17 00:00:00 2001 From: Tantalus Date: Thu, 5 Aug 2021 17:11:38 +0800 Subject: [PATCH 08/21] format fix --- src/relay/transforms/convert_sparse_conv2d.cc | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/relay/transforms/convert_sparse_conv2d.cc b/src/relay/transforms/convert_sparse_conv2d.cc index c21d8ca81cb7..c12d346be066 100644 --- a/src/relay/transforms/convert_sparse_conv2d.cc +++ b/src/relay/transforms/convert_sparse_conv2d.cc @@ -306,8 +306,7 @@ Pass Conv2dToSparse(const Array& weight_name, const Array( - Conv2dToSparse(f, weight_name, weight_shape, layout, kernel_size) - ); + Conv2dToSparse(f, weight_name, weight_shape, layout, kernel_size)); Array sparse_params = FreeVars(f0); auto f1 = Function(sparse_params, f0->body, f0->ret_type, f0->type_params, f0->attrs); Array params = FreeVars(f1); @@ -328,8 +327,7 @@ Pass Conv2dToSparse2(const String& layout, int kernel_size, runtime::TypedPackedFunc pass_func = [=](Function f, IRModule m, PassContext pc) { auto f0 = Downcast( - Conv2dToSparse2(f, layout, kernel_size, blockH, blockW, sparse_thresh) - ); + Conv2dToSparse2(f, layout, kernel_size, blockH, blockW, sparse_thresh)); return f0; }; return CreateFunctionPass(pass_func, 5, "Conv2dToSparse2", {"DeadCodeElimination"}); From bed9f4016686dd94bbc81e0d1701b612ebd1e4dd Mon Sep 17 00:00:00 2001 From: Tantalus Date: Thu, 5 Aug 2021 17:32:43 +0800 Subject: [PATCH 09/21] format fix --- include/tvm/relay/attrs/nn.h | 5 +- src/relay/transforms/convert_sparse_conv2d.cc | 109 ++++++++---------- 2 files changed, 49 insertions(+), 65 deletions(-) diff --git a/include/tvm/relay/attrs/nn.h b/include/tvm/relay/attrs/nn.h index f4eff9120e4a..4dd1c3d09208 100644 --- a/include/tvm/relay/attrs/nn.h +++ b/include/tvm/relay/attrs/nn.h @@ -1054,8 +1054,9 @@ struct SparseConv2DAttrs : public tvm::AttrsNode { "Dimension ordering of input data. Can be 'NCHW', 'NHWC'" "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" "dimensions respectively."); - TVM_ATTR_FIELD(kernel_size).set_default(1).describe( - "Kernel size for SparseConv2D, 1x1 or 3x3. "); + TVM_ATTR_FIELD(kernel_size) + .set_default(1) + .describe("Kernel size for SparseConv2D, 1x1 or 3x3. 
"); } }; diff --git a/src/relay/transforms/convert_sparse_conv2d.cc b/src/relay/transforms/convert_sparse_conv2d.cc index c12d346be066..4393b533d7a8 100644 --- a/src/relay/transforms/convert_sparse_conv2d.cc +++ b/src/relay/transforms/convert_sparse_conv2d.cc @@ -73,8 +73,8 @@ TVM_REGISTER_GLOBAL("relay.analysis.search_conv2d_op_weight").set_body_typed(Sea class Conv2dToSparseConv2dMutator : public ExprRewriter { public: Conv2dToSparseConv2dMutator(const Array& weight_name, - const Array>& weight_shape, - const String& layout, int kernel_size) + const Array>& weight_shape, const String& layout, + int kernel_size) : conv2d_op_(Op::Get("nn.conv2d")), sparse_conv2d_op_(Op::Get("nn.sparse_conv2d")) { ICHECK_EQ(weight_name.size(), weight_shape.size()); layout_ = layout; @@ -133,91 +133,79 @@ class Conv2dToSparseConv2dMutator : public ExprRewriter { }; // class Conv2dToSparseConv2dAlter Expr Conv2dToSparse(const Expr& e, const Array& weight_name, - const Array>& weight_shape, - const String& layout, int kernel_size) { + const Array>& weight_shape, const String& layout, + int kernel_size) { auto rewriter = Conv2dToSparseConv2dMutator(weight_name, weight_shape, layout, kernel_size); return PostOrderRewrite(e, &rewriter); } - template -auto unpack_to_tuple_internal(elemTy *arr, std::index_sequence) { +auto unpack_to_tuple_internal(elemTy* arr, std::index_sequence) { return std::make_tuple(arr[Is]...); } template -auto unpack_to_tuple(elemTy *arr) { +auto unpack_to_tuple(elemTy* arr) { return unpack_to_tuple_internal(arr, std::make_index_sequence{}); } struct Range { size_t dim; - explicit Range(size_t d): dim(d) {} + explicit Range(size_t d) : dim(d) {} struct iterpoint { size_t val, lim; - iterpoint(size_t v1, size_t v2): val(v1), lim(v2) {} + iterpoint(size_t v1, size_t v2) : val(v1), lim(v2) {} - size_t operator*() const { - return val; - } + size_t operator*() const { return val; } - iterpoint operator/(const iterpoint &rhs) const { + iterpoint operator/(const iterpoint& rhs) const { return iterpoint(val * rhs.lim + rhs.val, lim * rhs.lim); } }; struct iterator { size_t val, lim; - iterator(size_t v1, size_t v2): val(v1), lim(v2) {} + iterator(size_t v1, size_t v2) : val(v1), lim(v2) {} - bool operator!=(const iterator &rhs) const { - return val != rhs.val; - } + bool operator!=(const iterator &rhs) const { return val != rhs.val; } - void operator++() { - ++val; - } + void operator++() { ++val; } - iterpoint operator*() const { - return iterpoint(val, lim); - } + iterpoint operator*() const { return iterpoint(val, lim); } }; - iterator begin() { - return iterator(0, dim); - } + iterator begin() { return iterator(0, dim); } - iterator end() { - return iterator(dim, dim); - } + iterator end() { return iterator(dim, dim); } }; - // Mutate ```nn.conv2d``` to ```nn.sparse_conv2d``` class Conv2dToSparseConv2dMutator2 : public ExprRewriter { public: - Conv2dToSparseConv2dMutator2(const String& layout, int kernel_size, - int blockH, int blockW, double sparse_thresh) - : sparse_conv2d_op_(Op::Get("nn.sparse_conv2d")), dev_cpu0_{DLDeviceType::kDLCPU, 0}, - layout_(layout), kernel_size_(kernel_size), - blockH_(blockH), blockW_(blockW), sparse_thresh_(sparse_thresh) { } + Conv2dToSparseConv2dMutator2(const String& layout, int kernel_size, int blockH, int blockW, + double sparse_thresh) + : sparse_conv2d_op_(Op::Get("nn.sparse_conv2d")), + dev_cpu0_{DLDeviceType::kDLCPU, 0}, + layout_(layout), + kernel_size_(kernel_size), + blockH_(blockH), + blockW_(blockW), + sparse_thresh_(sparse_thresh) {} Expr 
Rewrite_(const CallNode* pre, const Expr& post) override { // check op type & attrs const auto pre_attrs = pre->attrs.as(); - if (!pre_attrs - || pre_attrs->data_layout != layout_ - || pre_attrs->strides[0].as()->value != 1 - || pre_attrs->kernel_size[0].as()->value != kernel_size_) + if (!pre_attrs || pre_attrs->data_layout != layout_ || + pre_attrs->strides[0].as()->value != 1 || + pre_attrs->kernel_size[0].as()->value != kernel_size_) return post; // check constant weight const auto pre_weight_node = pre->args[1].as(); - if (!pre_weight_node) - return post; + if (!pre_weight_node) return post; // check weight dtype & shape - auto &&pre_weight = pre_weight_node->data; + auto&& pre_weight = pre_weight_node->data; auto dtype = pre_weight.DataType(), itype = runtime::DataType::Int(32); ICHECK(dtype.code() == DataType::kFloat && dtype.bits() == 32); // float32 only auto pre_weight_shape = unpack_to_tuple<4>(pre_weight.Shape().data()); @@ -235,8 +223,7 @@ class Conv2dToSparseConv2dMutator2 : public ExprRewriter { if (layout_ == "NHWC") { std::vector tmp(pre_weight_data.size()); for (auto i : Range(CO)) - for (auto j : Range(CI)) - tmp[*(i / j)] = pre_weight_data[*(j / i)]; + for (auto j : Range(CI)) tmp[*(i / j)] = pre_weight_data[*(j / i)]; std::swap(tmp, pre_weight_data); } // convert to BSR @@ -247,11 +234,11 @@ class Conv2dToSparseConv2dMutator2 : public ExprRewriter { for (auto bw : Range(CI / blockW_)) { int cntnnz = 0; for (auto i : Range(blockH_)) - for (auto j : Range(blockW_)) { - auto tmp = pre_weight_data[*(bh / i / bw / j)]; - if (tmp) cntnnz++; - block[*(i / j)] = tmp; - } + for (auto j : Range(blockW_)) { + auto tmp = pre_weight_data[*(bh / i / bw / j)]; + if (tmp) cntnnz++; + block[*(i / j)] = tmp; + } if (cntnnz) { wdata.insert(wdata.end(), block.begin(), block.end()); windices.push_back(*bw); @@ -272,10 +259,8 @@ class Conv2dToSparseConv2dMutator2 : public ExprRewriter { weight_indptr.CopyFromBytes(windptr.data(), windptr.size() * sizeof(int32_t)); // construct return call - auto args = runtime::Array { - post.as()->args[0], - Constant(weight_data), Constant(weight_indices), Constant(weight_indptr) - }; + auto args = runtime::Array{post.as()->args[0], Constant(weight_data), + Constant(weight_indices), Constant(weight_indptr)}; auto attrs = make_object(); attrs->layout = layout_; attrs->kernel_size = kernel_size_; @@ -283,20 +268,19 @@ class Conv2dToSparseConv2dMutator2 : public ExprRewriter { } private: - const Op &sparse_conv2d_op_; + const Op& sparse_conv2d_op_; DLDevice dev_cpu0_; String layout_; int kernel_size_, blockH_, blockW_; double sparse_thresh_; }; // class Conv2dToSparseConv2dMutator2 -Expr Conv2dToSparse2(const Expr& e, const String& layout, int kernel_size, - int blockH, int blockW, double sparse_thresh) { +Expr Conv2dToSparse2(const Expr& e, const String& layout, int kernel_size, int blockH, int blockW, + double sparse_thresh) { auto rewriter = Conv2dToSparseConv2dMutator2(layout, kernel_size, blockH, blockW, sparse_thresh); return PostOrderRewrite(e, &rewriter); } - namespace transform { // Convert a model with seperate weight info (already sparsified). 
@@ -305,8 +289,8 @@ Pass Conv2dToSparse(const Array& weight_name, const Array pass_func = [=](Function f, IRModule m, PassContext pc) { // Remove FreeVar warnings - auto f0 = Downcast( - Conv2dToSparse(f, weight_name, weight_shape, layout, kernel_size)); + auto f0 = + Downcast(Conv2dToSparse(f, weight_name, weight_shape, layout, kernel_size)); Array sparse_params = FreeVars(f0); auto f1 = Function(sparse_params, f0->body, f0->ret_type, f0->type_params, f0->attrs); Array params = FreeVars(f1); @@ -320,14 +304,13 @@ Pass Conv2dToSparse(const Array& weight_name, const Array pass_func = [=](Function f, IRModule m, PassContext pc) { auto f0 = Downcast( - Conv2dToSparse2(f, layout, kernel_size, blockH, blockW, sparse_thresh)); + Conv2dToSparse2(f, layout, kernel_size, blockH, blockW, sparse_thresh)); return f0; }; return CreateFunctionPass(pass_func, 5, "Conv2dToSparse2", {"DeadCodeElimination"}); From 14e2934aa34cdf9ff4e4833b7c66ef10fc949ef8 Mon Sep 17 00:00:00 2001 From: Tantalus Date: Thu, 5 Aug 2021 17:35:08 +0800 Subject: [PATCH 10/21] format fix --- src/relay/transforms/convert_sparse_conv2d.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/relay/transforms/convert_sparse_conv2d.cc b/src/relay/transforms/convert_sparse_conv2d.cc index 4393b533d7a8..d6e87dbd0ffd 100644 --- a/src/relay/transforms/convert_sparse_conv2d.cc +++ b/src/relay/transforms/convert_sparse_conv2d.cc @@ -289,7 +289,7 @@ Pass Conv2dToSparse(const Array& weight_name, const Array pass_func = [=](Function f, IRModule m, PassContext pc) { // Remove FreeVar warnings - auto f0 = + auto f0 = Downcast(Conv2dToSparse(f, weight_name, weight_shape, layout, kernel_size)); Array sparse_params = FreeVars(f0); auto f1 = Function(sparse_params, f0->body, f0->ret_type, f0->type_params, f0->attrs); From ef7ae5c4641677128b819d9f5da160a630b63202 Mon Sep 17 00:00:00 2001 From: Tantalus Date: Thu, 5 Aug 2021 17:42:43 +0800 Subject: [PATCH 11/21] format fix --- src/relay/transforms/convert_sparse_conv2d.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/relay/transforms/convert_sparse_conv2d.cc b/src/relay/transforms/convert_sparse_conv2d.cc index d6e87dbd0ffd..39f19085c336 100644 --- a/src/relay/transforms/convert_sparse_conv2d.cc +++ b/src/relay/transforms/convert_sparse_conv2d.cc @@ -168,7 +168,7 @@ struct Range { size_t val, lim; iterator(size_t v1, size_t v2) : val(v1), lim(v2) {} - bool operator!=(const iterator &rhs) const { return val != rhs.val; } + bool operator!=(const iterator& rhs) const { return val != rhs.val; } void operator++() { ++val; } @@ -260,7 +260,7 @@ class Conv2dToSparseConv2dMutator2 : public ExprRewriter { // construct return call auto args = runtime::Array{post.as()->args[0], Constant(weight_data), - Constant(weight_indices), Constant(weight_indptr)}; + Constant(weight_indices), Constant(weight_indptr)}; auto attrs = make_object(); attrs->layout = layout_; attrs->kernel_size = kernel_size_; From bccacc6e665f4d6c6a5fb7238ff019f94bc969e3 Mon Sep 17 00:00:00 2001 From: Tantalus Date: Fri, 6 Aug 2021 02:08:37 +0800 Subject: [PATCH 12/21] format fix --- python/tvm/relay/analysis/sparse_conv2d.py | 4 +- .../relay/data_dep_optimization/bsr_conv2d.py | 4 +- python/tvm/relay/op/nn/_nn.py | 6 +- python/tvm/topi/nn/sparse.py | 4 +- python/tvm/topi/x86/sparse.py | 69 +++++++++---------- 5 files changed, 48 insertions(+), 39 deletions(-) diff --git a/python/tvm/relay/analysis/sparse_conv2d.py b/python/tvm/relay/analysis/sparse_conv2d.py index 0bb865db536d..1862ded831f6 100644 --- 
a/python/tvm/relay/analysis/sparse_conv2d.py +++ b/python/tvm/relay/analysis/sparse_conv2d.py @@ -54,7 +54,9 @@ def _search_conv2d_op_weight(expr): return _ffi_api.search_conv2d_op_weight(expr) -def process_params(expr, params, block_size, sparsity_threshold, layout, kernel_size, reg_task_input=True): +def process_params( + expr, params, block_size, sparsity_threshold, layout, kernel_size, reg_task_input=True +): """Process parameters of conv2d from dense to sparse. Parameters diff --git a/python/tvm/relay/data_dep_optimization/bsr_conv2d.py b/python/tvm/relay/data_dep_optimization/bsr_conv2d.py index 2133af5711b4..b97fbe44d7cb 100644 --- a/python/tvm/relay/data_dep_optimization/bsr_conv2d.py +++ b/python/tvm/relay/data_dep_optimization/bsr_conv2d.py @@ -52,7 +52,9 @@ def convert(func, params, blocksize, sparsity_threshold, layout="NHWC", kernel_s weight_info = process_params(func, params, blocksize, sparsity_threshold, layout, kernel_size) new_func = _run_opt_pass( func, - relay.transform.Conv2dToSparse(weight_info.weight_name, weight_info.weight_shape, layout, kernel_size), + relay.transform.Conv2dToSparse( + weight_info.weight_name, weight_info.weight_shape, layout, kernel_size + ), ) return new_func, params diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py index c6a6c6ca8221..da2985cb441c 100644 --- a/python/tvm/relay/op/nn/_nn.py +++ b/python/tvm/relay/op/nn/_nn.py @@ -198,7 +198,11 @@ def compute_sparse_transpose(attrs, inputs, out_type): @reg.register_compute("nn.sparse_conv2d") def compute_sparse_conv2d(attrs, inputs, out_type): """Compute definition of sparse_conv2d""" - return [topi.nn.sparse_conv2d(inputs[0], inputs[1], inputs[2], inputs[3], attrs["layout"], attrs["kernel_size"])] + return [ + topi.nn.sparse_conv2d( + inputs[0], inputs[1], inputs[2], inputs[3], attrs["layout"], attrs["kernel_size"] + ) + ] reg.register_strategy("nn.sparse_conv2d", strategy.sparse_conv2d_strategy) diff --git a/python/tvm/topi/nn/sparse.py b/python/tvm/topi/nn/sparse.py index c6b50fc3c8c5..e577104c3ddc 100644 --- a/python/tvm/topi/nn/sparse.py +++ b/python/tvm/topi/nn/sparse.py @@ -566,7 +566,9 @@ def _compute_block(i, nb_j, j, h, w): # pylint: disable=C0103 ) -def sparse_conv2d(dense_data, sparse_data, sparse_indices, sparse_indptr, layout="NHWC", kernel_size=1): +def sparse_conv2d( + dense_data, sparse_data, sparse_indices, sparse_indptr, layout="NHWC", kernel_size=1 +): """ Computes sparse-conv2d(1*1) of ``data`` and ``(weight_data, weight_indices, weight_indptr)`` diff --git a/python/tvm/topi/x86/sparse.py b/python/tvm/topi/x86/sparse.py index be24d7ef5d96..7434e5dda4b7 100644 --- a/python/tvm/topi/x86/sparse.py +++ b/python/tvm/topi/x86/sparse.py @@ -70,36 +70,39 @@ def spconv2d_3x3_nhwc(cfg, Data, Wdat, Wind, Wptr, layout="NHWC"): nElems, bsrR, bsrC = [i.value for i in Wdat.shape] CO = (Wptr.shape[0].value - 1) * bsrR - Y, X, K = N*H*W, CO, 9*CI + Y, X, K = N * H * W, CO, 9 * CI cfg.define_split("tile_y", Y, num_outputs=3) cfg.define_split("tile_x", X // bsrR, num_outputs=2) cfg.add_flop(Y * (nElems * bsrC * bsrR * 2 - X)) if cfg.is_fallback: cfg["tile_y"] = autotvm.task.space.SplitEntity([-1, 160, 8]) cfg["tile_x"] = autotvm.task.space.SplitEntity([-1, 4]) - - idxsplit = lambda x,y: reduce(lambda a,b: a[:-1]+[a[-1]%b,a[-1]//b], y, [x]) + + idxsplit = lambda x, y: reduce(lambda a, b: a[:-1] + [a[-1] % b, a[-1] // b], y, [x]) @partial(te.compute, (Y, K), name="Im2Col") def Im2Col(row, col): jw, jh, jn = idxsplit(row, [W, H]) jc, kw, kh = idxsplit(col, [CI, 3]) ih, iw 
= jh + kh - 1, jw + kw - 1 - return tir.if_then_else( - tir.all(0 <= ih, ih < H, 0 <= iw, iw < W), - Data[jn, ih, iw, jc], 0) - + return tir.if_then_else(tir.all(0 <= ih, ih < H, 0 <= iw, iw < W), Data[jn, ih, iw, jc], 0) + @partial(te.compute, (Y, X // bsrR, bsrR, bsrC), name="CC") def CC(drow, wrow, brow, bcol): - row_start, row_end = Wptr[wrow], Wptr[wrow+1] + row_start, row_end = Wptr[wrow], Wptr[wrow + 1] elem_idx = te.reduce_axis((0, row_end - row_start), name="elem_idx") elem = row_start + elem_idx - return te.sum(Im2Col[drow, Wind[elem]*bsrC + bcol] * Wdat[elem, brow, bcol], axis=elem_idx) + return te.sum( + Im2Col[drow, Wind[elem]*bsrC + bcol] * Wdat[elem, brow, bcol], axis=elem_idx + ) k = te.reduce_axis((0, bsrC), name="k") - C = te.compute((Y, X), + C = te.compute( + (Y, X), lambda y, x: te.sum(CC[y, x // bsrR, x % bsrR, k], axis=k), - name="C", tag="conv3x3_spNHWC") + name="C", + tag="conv3x3_spNHWC", + ) return reshape(C, (N, H, W, CO)) @@ -107,13 +110,13 @@ def CC(drow, wrow, brow, bcol): def schedule_spconv2d_3x3_nhwc(cfg, outs): outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs s = te.create_schedule([x.op for x in outs]) - + def _callback(op): if op.tag == "conv3x3_spNHWC": C = op - CC, = op.input_tensors + (CC,) = op.input_tensors Wptr, Wind, Im2Col, Wdat = CC.op.input_tensors - Data, = Im2Col.op.input_tensors + (Data,) = Im2Col.op.input_tensors bsrR = CC.shape[-2].value CI = Data.shape[-1].value @@ -134,12 +137,11 @@ def _callback(op): s[CC].unroll(c) s[CC].vectorize(r) s[CC].unroll(yi) - + s[Im2Col].compute_at(s[C], yo) yi, k = s[Im2Col].op.axis ko, ki = s[Im2Col].split(k, factor=CI) s[Im2Col].vectorize(ki) - #s[Im2Col].unroll(yi) traverse_inline(s, outs[0].op, _callback) return s @@ -151,28 +153,25 @@ def spconv2d_3x3_nchw(cfg, Data, Wdat, Wind, Wptr, layout="NCHW"): NNZ, VL, bsrC = [i.value for i in Wdat.shape] CO = (Wptr.shape[0].value - 1) * VL assert bsrC == 1 - - cfg.add_flop(N*H*W * (NNZ * VL * bsrC * 2 - CO)) - cfg.define_split("tile_hw", H*W, num_outputs=3) - cfg.define_split("tile_ckk", CI*9, num_outputs=3) - - @partial(te.compute, (N, CI*3*3, H*W), name="im2col") + + cfg.add_flop(N * H * W * (NNZ * VL * bsrC * 2 - CO)) + cfg.define_split("tile_hw", H * W, num_outputs=3) + cfg.define_split("tile_ckk", CI * 9, num_outputs=3) + + @partial(te.compute, (N, CI * 3 * 3, H * W), name="im2col") def Im2Col(n, ckk, hw): jh, jw = hw // W, hw % W ic, kh, kw = ckk // 9, ckk // 3 % 3, ckk % 3 ih, iw = jh + kh - 1, jw + kw - 1 - return tir.if_then_else( - tir.all(0 <= ih, ih < H, 0 <= iw, iw < W), - Data[n, ic, ih, iw], 0) - - @partial(te.compute, (N, CO // VL, VL, bsrC, H*W), name="CC", tag="conv3x3_spNCHW") + return tir.if_then_else(tir.all(0 <= ih, ih < H, 0 <= iw, iw < W), Data[n, ic, ih, iw], 0) + + @partial(te.compute, (N, CO // VL, VL, bsrC, H * W), name="CC", tag="conv3x3_spNCHW") def CC(n, fo, fi, k, hw): - row_start, row_end = Wptr[fo], Wptr[fo+1] + row_start, row_end = Wptr[fo], Wptr[fo + 1] elem_idx = te.reduce_axis((0, row_end - row_start), name="elem_idx") elem = row_start + elem_idx - return te.sum(Im2Col[n, Wind[elem] * bsrC + k, hw] * Wdat[elem, fi, k], - axis=elem_idx) - + return te.sum(Im2Col[n, Wind[elem] * bsrC + k, hw] * Wdat[elem, fi, k], axis=elem_idx) + return reshape(CC, [N, CO, H, W]) @@ -180,15 +179,15 @@ def CC(n, fo, fi, k, hw): def schedule_spconv2d_3x3_nchw(cfg, outs): outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs s = te.create_schedule([x.op for x in outs]) - + def _callback(op): if op.tag == "conv3x3_spNCHW": CC 
= op Wptr, Wind, im2col, Wdat = op.input_tensors - Data, = im2col.op.input_tensors - + (Data,) = im2col.op.input_tensors + n, fo, fi, bc, hw = s[CC].op.axis - kk, = s[CC].op.reduce_axis + (kk,) = s[CC].op.reduce_axis hw1, hw2, hw3 = cfg["tile_hw"].apply(s, CC, hw) s[CC].reorder(n, hw1, fo, hw2, kk, fi, bc, hw3) s[CC].unroll(fi) From cd8082787597b5434958cd36cc63e6d9f5fe4bca Mon Sep 17 00:00:00 2001 From: Tantalus Date: Fri, 6 Aug 2021 02:14:13 +0800 Subject: [PATCH 13/21] format fix --- python/tvm/topi/x86/sparse.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/topi/x86/sparse.py b/python/tvm/topi/x86/sparse.py index 7434e5dda4b7..4c844eacfbd9 100644 --- a/python/tvm/topi/x86/sparse.py +++ b/python/tvm/topi/x86/sparse.py @@ -93,7 +93,7 @@ def CC(drow, wrow, brow, bcol): elem_idx = te.reduce_axis((0, row_end - row_start), name="elem_idx") elem = row_start + elem_idx return te.sum( - Im2Col[drow, Wind[elem]*bsrC + bcol] * Wdat[elem, brow, bcol], axis=elem_idx + Im2Col[drow, Wind[elem] * bsrC + bcol] * Wdat[elem, brow, bcol], axis=elem_idx ) k = te.reduce_axis((0, bsrC), name="k") From 04c0e19aee76e378bfd089c428c9bef61cc4b70e Mon Sep 17 00:00:00 2001 From: Tantalus Date: Fri, 6 Aug 2021 02:26:34 +0800 Subject: [PATCH 14/21] format fix --- python/tvm/topi/x86/sparse.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/topi/x86/sparse.py b/python/tvm/topi/x86/sparse.py index 4c844eacfbd9..ddd34d441b3d 100644 --- a/python/tvm/topi/x86/sparse.py +++ b/python/tvm/topi/x86/sparse.py @@ -203,4 +203,4 @@ def _callback(op): s[im2col].vectorize(hw3) traverse_inline(s, outs[0].op, _callback) - return s \ No newline at end of file + return s From 380adb51416e3b1842adb691f2d83d7fa9592c4c Mon Sep 17 00:00:00 2001 From: Tantalus Date: Mon, 9 Aug 2021 02:33:46 +0800 Subject: [PATCH 15/21] format fix --- python/tvm/topi/x86/sparse.py | 198 ++++++++++++++++++---------------- 1 file changed, 106 insertions(+), 92 deletions(-) diff --git a/python/tvm/topi/x86/sparse.py b/python/tvm/topi/x86/sparse.py index ddd34d441b3d..1df2fe053f44 100644 --- a/python/tvm/topi/x86/sparse.py +++ b/python/tvm/topi/x86/sparse.py @@ -16,8 +16,8 @@ # under the License. 
"""sparse_dense schedule on x86""" -from tvm import te, tir, autotvm from functools import partial, reduce +from tvm import te, tir, autotvm from ..transform import reshape from ..utils import traverse_inline, get_const_int @@ -65,140 +65,154 @@ def _callback(op): @autotvm.register_topi_compute("conv3x3_spNHWC.x86") -def spconv2d_3x3_nhwc(cfg, Data, Wdat, Wind, Wptr, layout="NHWC"): - N, H, W, CI = [i.value for i in Data.shape] - nElems, bsrR, bsrC = [i.value for i in Wdat.shape] - CO = (Wptr.shape[0].value - 1) * bsrR - - Y, X, K = N * H * W, CO, 9 * CI - cfg.define_split("tile_y", Y, num_outputs=3) - cfg.define_split("tile_x", X // bsrR, num_outputs=2) - cfg.add_flop(Y * (nElems * bsrC * bsrR * 2 - X)) +def spconv2d_3x3_nhwc(cfg, data, wdat, wind, wptr, layout="NHWC"): + """Sparse Conv2d 3x3 compute (NHWC).""" + nsamples, imh, imw, chanin = [i.value for i in data.shape] + nelems, bsrr, bsrc = [i.value for i in wdat.shape] + chanout = (wptr.shape[0].value - 1) * bsrr + + imglen, chanlen = nsamples * imh * imw, 9 * chanin + cfg.define_split("tile_y", imglen, num_outputs=3) + cfg.define_split("tile_x", chanout // bsrr, num_outputs=2) + cfg.add_flop(imglen * (nelems * bsrc * bsrr * 2 - chanout)) if cfg.is_fallback: cfg["tile_y"] = autotvm.task.space.SplitEntity([-1, 160, 8]) cfg["tile_x"] = autotvm.task.space.SplitEntity([-1, 4]) idxsplit = lambda x, y: reduce(lambda a, b: a[:-1] + [a[-1] % b, a[-1] // b], y, [x]) - @partial(te.compute, (Y, K), name="Im2Col") - def Im2Col(row, col): - jw, jh, jn = idxsplit(row, [W, H]) - jc, kw, kh = idxsplit(col, [CI, 3]) - ih, iw = jh + kh - 1, jw + kw - 1 - return tir.if_then_else(tir.all(0 <= ih, ih < H, 0 <= iw, iw < W), Data[jn, ih, iw, jc], 0) + @partial(te.compute, (imglen, chanlen), name="Im2Col") + def im2col(row, col): + j_w, j_h, j_n = idxsplit(row, [imw, imh]) + j_c, k_w, k_h = idxsplit(col, [chanin, 3]) + i_h, i_w = j_h + k_h - 1, j_w + k_w - 1 + return tir.if_then_else( + tir.all(i_h >= 0, i_h < imh, i_w >= 0, i_w < imw), data[j_n, i_h, i_w, j_c], 0 + ) - @partial(te.compute, (Y, X // bsrR, bsrR, bsrC), name="CC") - def CC(drow, wrow, brow, bcol): - row_start, row_end = Wptr[wrow], Wptr[wrow + 1] + @partial(te.compute, (imglen, chanout // bsrr, bsrr, bsrc), name="CC") + def matmul(drow, wrow, brow, bcol): + row_start, row_end = wptr[wrow], wptr[wrow + 1] elem_idx = te.reduce_axis((0, row_end - row_start), name="elem_idx") elem = row_start + elem_idx return te.sum( - Im2Col[drow, Wind[elem] * bsrC + bcol] * Wdat[elem, brow, bcol], axis=elem_idx + im2col[drow, wind[elem] * bsrc + bcol] * wdat[elem, brow, bcol], axis=elem_idx ) - k = te.reduce_axis((0, bsrC), name="k") - C = te.compute( - (Y, X), - lambda y, x: te.sum(CC[y, x // bsrR, x % bsrR, k], axis=k), + sum_bsrc = te.reduce_axis((0, bsrc), name="k") + ret = te.compute( + (imglen, chanout), + lambda y, x: te.sum(matmul[y, x // bsrr, x % bsrr, sum_bsrc], axis=sum_bsrc), name="C", tag="conv3x3_spNHWC", ) - return reshape(C, (N, H, W, CO)) + return reshape(ret, (nsamples, imh, imw, chanout)) @autotvm.register_topi_schedule("conv3x3_spNHWC.x86") def schedule_spconv2d_3x3_nhwc(cfg, outs): + """Sparse Conv2d 3x3 schedule (NHWC).""" outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs s = te.create_schedule([x.op for x in outs]) def _callback(op): if op.tag == "conv3x3_spNHWC": - C = op - (CC,) = op.input_tensors - Wptr, Wind, Im2Col, Wdat = CC.op.input_tensors - (Data,) = Im2Col.op.input_tensors - bsrR = CC.shape[-2].value - CI = Data.shape[-1].value - - y, x = s[C].op.axis - yt, yo, yi = 
cfg["tile_y"].apply(s, C, y) - xo, xi = s[C].split(x, factor=bsrR) - xt, xo = cfg["tile_x"].apply(s, C, xo) - (k,) = s[C].op.reduce_axis - s[C].reorder(yt, xt, yo, xo, yi, xi, k) - s[C].unroll(k) - s[C].vectorize(xi) - s[C].unroll(yi) - - s[CC].compute_at(s[C], xo) - yi, xi, r, c = s[CC].op.axis - (k,) = s[CC].op.reduce_axis - s[CC].reorder(xi, k, yi, r, c) - s[CC].unroll(c) - s[CC].vectorize(r) - s[CC].unroll(yi) - - s[Im2Col].compute_at(s[C], yo) - yi, k = s[Im2Col].op.axis - ko, ki = s[Im2Col].split(k, factor=CI) - s[Im2Col].vectorize(ki) + (matmul,) = op.input_tensors + wptr, wind, im2col, wdat = matmul.op.input_tensors + (data,) = im2col.op.input_tensors + bsrr = matmul.shape[-2].value + chanin = data.shape[-1].value + + mm_y, mm_x = s[op].op.axis + y_t, y_o, y_i = cfg["tile_y"].apply(s, op, mm_y) + x_o, x_i = s[op].split(mm_x, factor=bsrr) + x_t, x_o = cfg["tile_x"].apply(s, op, x_o) + (sum_ax,) = s[op].op.reduce_axis + s[op].reorder(y_t, x_t, y_o, x_o, y_i, x_i, sum_ax) + s[op].unroll(sum_ax) + s[op].vectorize(x_i) + s[op].unroll(y_i) + + s[matmul].compute_at(s[op], x_o) + y_i, x_i, bsrr, bsrc = s[matmul].op.axis + (sum_ax,) = s[matmul].op.reduce_axis + s[matmul].reorder(x_i, sum_ax, y_i, bsrr, bsrc) + s[matmul].unroll(bsrc) + s[matmul].vectorize(bsrr) + s[matmul].unroll(y_i) + + s[im2col].compute_at(s[op], y_o) + y_i, sum_ax = s[im2col].op.axis + k_o, k_i = s[im2col].split(sum_ax, factor=chanin) + s[im2col].vectorize(k_i) traverse_inline(s, outs[0].op, _callback) return s @autotvm.register_topi_compute("conv3x3_spNCHW.x86") -def spconv2d_3x3_nchw(cfg, Data, Wdat, Wind, Wptr, layout="NCHW"): - N, CI, H, W = [i.value for i in Data.shape] - NNZ, VL, bsrC = [i.value for i in Wdat.shape] - CO = (Wptr.shape[0].value - 1) * VL - assert bsrC == 1 - - cfg.add_flop(N * H * W * (NNZ * VL * bsrC * 2 - CO)) - cfg.define_split("tile_hw", H * W, num_outputs=3) - cfg.define_split("tile_ckk", CI * 9, num_outputs=3) - - @partial(te.compute, (N, CI * 3 * 3, H * W), name="im2col") - def Im2Col(n, ckk, hw): - jh, jw = hw // W, hw % W - ic, kh, kw = ckk // 9, ckk // 3 % 3, ckk % 3 - ih, iw = jh + kh - 1, jw + kw - 1 - return tir.if_then_else(tir.all(0 <= ih, ih < H, 0 <= iw, iw < W), Data[n, ic, ih, iw], 0) - - @partial(te.compute, (N, CO // VL, VL, bsrC, H * W), name="CC", tag="conv3x3_spNCHW") - def CC(n, fo, fi, k, hw): - row_start, row_end = Wptr[fo], Wptr[fo + 1] +def spconv2d_3x3_nchw(cfg, data, wdat, wind, wptr, layout="NCHW"): + """Sparse Conv2d 3x3 compute (NCHW).""" + nsamples, chanin, imgh, imgw = [i.value for i in data.shape] + nelems, veclen, bsrc = [i.value for i in wdat.shape] + chanout = (wptr.shape[0].value - 1) * veclen + assert bsrc == 1 + + cfg.add_flop(nsamples * imgh * imgw * (nelems * veclen * bsrc * 2 - chanout)) + cfg.define_split("tile_hw", imgh * imgw, num_outputs=3) + cfg.define_split("tile_ckk", chanin * 9, num_outputs=3) + + @partial(te.compute, (nsamples, chanin * 3 * 3, imgh * imgw), name="im2col") + def im2col(nsamples, ckk, imglen): + j_h, j_w = imglen // imgw, imglen % imgw + i_c, k_h, k_w = ckk // 9, ckk // 3 % 3, ckk % 3 + i_h, i_w = j_h + k_h - 1, j_w + k_w - 1 + return tir.if_then_else( + tir.all(i_h >= 0, i_h < imgh, i_w >= 0, i_w < imgw), data[nsamples, i_c, i_h, i_w], 0 + ) + + @partial( + te.compute, + (nsamples, chanout // veclen, veclen, bsrc, imgh * imgw), + name="CC", + tag="conv3x3_spNCHW", + ) + def matmul(nsamples, f_o, f_i, bsrk, imglen): + row_start, row_end = wptr[f_o], wptr[f_o + 1] elem_idx = te.reduce_axis((0, row_end - row_start), name="elem_idx") 
elem = row_start + elem_idx - return te.sum(Im2Col[n, Wind[elem] * bsrC + k, hw] * Wdat[elem, fi, k], axis=elem_idx) + return te.sum( + im2col[nsamples, wind[elem] * bsrc + bsrk, imglen] * wdat[elem, f_i, bsrk], + axis=elem_idx, + ) - return reshape(CC, [N, CO, H, W]) + return reshape(matmul, [nsamples, chanout, imgh, imgw]) @autotvm.register_topi_schedule("conv3x3_spNCHW.x86") def schedule_spconv2d_3x3_nchw(cfg, outs): + """Sparse Conv2d 3x3 schedule (NCHW).""" outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs s = te.create_schedule([x.op for x in outs]) def _callback(op): if op.tag == "conv3x3_spNCHW": - CC = op - Wptr, Wind, im2col, Wdat = op.input_tensors - (Data,) = im2col.op.input_tensors - - n, fo, fi, bc, hw = s[CC].op.axis - (kk,) = s[CC].op.reduce_axis - hw1, hw2, hw3 = cfg["tile_hw"].apply(s, CC, hw) - s[CC].reorder(n, hw1, fo, hw2, kk, fi, bc, hw3) - s[CC].unroll(fi) - s[CC].unroll(bc) - s[CC].vectorize(hw3) - - s[im2col].compute_at(s[CC], hw1) - n, ckk, hw = s[im2col].op.axis + wptr, wind, im2col, wdat = op.input_tensors + (data,) = im2col.op.input_tensors + + n_samples, f_o, f_i, b_c, imglen = s[op].op.axis + (sum_ax,) = s[op].op.reduce_axis + hw1, hw2, hw3 = cfg["tile_hw"].apply(s, op, imglen) + s[op].reorder(n_samples, hw1, f_o, hw2, sum_ax, f_i, b_c, hw3) + s[op].unroll(f_i) + s[op].unroll(b_c) + s[op].vectorize(hw3) + + s[im2col].compute_at(s[op], hw1) + n_samples, ckk, imglen = s[im2col].op.axis ckk1, ckk2, ckk3 = cfg["tile_ckk"].apply(s, im2col, ckk) - hw2, hw3 = s[im2col].split(hw, factor=cfg["tile_hw"].size[-1]) - s[im2col].reorder(n, ckk1, ckk2, hw2, ckk3, hw3) + hw2, hw3 = s[im2col].split(imglen, factor=cfg["tile_hw"].size[-1]) + s[im2col].reorder(n_samples, ckk1, ckk2, hw2, ckk3, hw3) s[im2col].unroll(ckk3) s[im2col].vectorize(hw3) From 3801554a4442464f3bc5fb3a9f53f167d60c0b6e Mon Sep 17 00:00:00 2001 From: Tantalus Date: Mon, 9 Aug 2021 02:42:53 +0800 Subject: [PATCH 16/21] format fix --- python/tvm/topi/x86/sparse.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/python/tvm/topi/x86/sparse.py b/python/tvm/topi/x86/sparse.py index 1df2fe053f44..48ec233fa4bb 100644 --- a/python/tvm/topi/x86/sparse.py +++ b/python/tvm/topi/x86/sparse.py @@ -67,6 +67,7 @@ def _callback(op): @autotvm.register_topi_compute("conv3x3_spNHWC.x86") def spconv2d_3x3_nhwc(cfg, data, wdat, wind, wptr, layout="NHWC"): """Sparse Conv2d 3x3 compute (NHWC).""" + assert layout == "NHWC" nsamples, imh, imw, chanin = [i.value for i in data.shape] nelems, bsrr, bsrc = [i.value for i in wdat.shape] chanout = (wptr.shape[0].value - 1) * bsrr @@ -118,7 +119,8 @@ def schedule_spconv2d_3x3_nhwc(cfg, outs): def _callback(op): if op.tag == "conv3x3_spNHWC": (matmul,) = op.input_tensors - wptr, wind, im2col, wdat = matmul.op.input_tensors + # wptr, wind, im2col, wdat + _, _, im2col, _ = matmul.op.input_tensors (data,) = im2col.op.input_tensors bsrr = matmul.shape[-2].value chanin = data.shape[-1].value @@ -143,7 +145,7 @@ def _callback(op): s[im2col].compute_at(s[op], y_o) y_i, sum_ax = s[im2col].op.axis - k_o, k_i = s[im2col].split(sum_ax, factor=chanin) + _, k_i = s[im2col].split(sum_ax, factor=chanin) s[im2col].vectorize(k_i) traverse_inline(s, outs[0].op, _callback) @@ -156,7 +158,7 @@ def spconv2d_3x3_nchw(cfg, data, wdat, wind, wptr, layout="NCHW"): nsamples, chanin, imgh, imgw = [i.value for i in data.shape] nelems, veclen, bsrc = [i.value for i in wdat.shape] chanout = (wptr.shape[0].value - 1) * veclen - assert bsrc == 1 + assert bsrc == 1 and layout == 
"NCHW" cfg.add_flop(nsamples * imgh * imgw * (nelems * veclen * bsrc * 2 - chanout)) cfg.define_split("tile_hw", imgh * imgw, num_outputs=3) @@ -197,8 +199,8 @@ def schedule_spconv2d_3x3_nchw(cfg, outs): def _callback(op): if op.tag == "conv3x3_spNCHW": - wptr, wind, im2col, wdat = op.input_tensors - (data,) = im2col.op.input_tensors + # wptr, wind, im2col, wdat + _, _, im2col, _ = op.input_tensors n_samples, f_o, f_i, b_c, imglen = s[op].op.axis (sum_ax,) = s[op].op.reduce_axis From 8162d765671010ce831ef6b31e15f54a76d03eb8 Mon Sep 17 00:00:00 2001 From: Tantalus Date: Mon, 9 Aug 2021 16:13:16 +0800 Subject: [PATCH 17/21] use array for sparse conv2d attr --- include/tvm/relay/attrs/nn.h | 4 ++-- src/relay/op/nn/sparse.cc | 4 ++-- src/relay/transforms/convert_sparse_conv2d.cc | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/include/tvm/relay/attrs/nn.h b/include/tvm/relay/attrs/nn.h index 4dd1c3d09208..00d1b9b93462 100644 --- a/include/tvm/relay/attrs/nn.h +++ b/include/tvm/relay/attrs/nn.h @@ -1047,7 +1047,7 @@ struct SparseTransposeAttrs : public tvm::AttrsNode { /*! \brief Attributes for sparse_dense operator */ struct SparseConv2DAttrs : public tvm::AttrsNode { std::string layout; - int kernel_size; + Array kernel_size; TVM_DECLARE_ATTRS(SparseConv2DAttrs, "relay.attrs.SparseConv2DAttrs") { TVM_ATTR_FIELD(layout).set_default("NHWC").describe( @@ -1055,7 +1055,7 @@ struct SparseConv2DAttrs : public tvm::AttrsNode { "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" "dimensions respectively."); TVM_ATTR_FIELD(kernel_size) - .set_default(1) + .set_default(Array {1, 1}) .describe("Kernel size for SparseConv2D, 1x1 or 3x3. "); } }; diff --git a/src/relay/op/nn/sparse.cc b/src/relay/op/nn/sparse.cc index 99fb4699123d..7d21005cb4db 100644 --- a/src/relay/op/nn/sparse.cc +++ b/src/relay/op/nn/sparse.cc @@ -274,11 +274,11 @@ bool SparseConv2dRel(const Array& types, int num_inputs, const Attrs& attr } Expr MakeSparseConv2d(Expr data, Expr weight_data, Expr weight_indices, Expr weight_indptr, - std::string layout, int kernel_size) { + std::string layout, Array kernel_size) { static const Op& op = Op::Get("nn.sparse_conv2d"); auto attrs = make_object(); attrs->layout = std::move(layout); - attrs->kernel_size = kernel_size; + attrs->kernel_size = std::move(kernel_size); return Call(op, {data, weight_data, weight_indices, weight_indptr}, Attrs(attrs), {}); } diff --git a/src/relay/transforms/convert_sparse_conv2d.cc b/src/relay/transforms/convert_sparse_conv2d.cc index 39f19085c336..9ea873f12a77 100644 --- a/src/relay/transforms/convert_sparse_conv2d.cc +++ b/src/relay/transforms/convert_sparse_conv2d.cc @@ -114,7 +114,7 @@ class Conv2dToSparseConv2dMutator : public ExprRewriter { Var weight_indptr(prefix + ".indptr", ws_indptr_type); auto attrs = make_object(); attrs->layout = std::move(layout_); - attrs->kernel_size = kernel_size_; + attrs->kernel_size = Array {kernel_size_, kernel_size_}; return Call(sparse_conv2d_op_, {data, weight_data, weight_indices, weight_indptr}, Attrs(attrs)); } @@ -263,7 +263,7 @@ class Conv2dToSparseConv2dMutator2 : public ExprRewriter { Constant(weight_indices), Constant(weight_indptr)}; auto attrs = make_object(); attrs->layout = layout_; - attrs->kernel_size = kernel_size_; + attrs->kernel_size = Array {kernel_size_, kernel_size_}; return Call(sparse_conv2d_op_, args, Attrs(attrs)); } From 1b1fbde02b40855e96b9f5910bdf0e533d892812 Mon Sep 17 00:00:00 2001 From: Tantalus Date: Mon, 9 Aug 2021 16:18:15 +0800 Subject: [PATCH 18/21] 
format fix

---
 include/tvm/relay/attrs/nn.h                  | 2 +-
 src/relay/transforms/convert_sparse_conv2d.cc | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/include/tvm/relay/attrs/nn.h b/include/tvm/relay/attrs/nn.h
index 00d1b9b93462..b1f2badff8b6 100644
--- a/include/tvm/relay/attrs/nn.h
+++ b/include/tvm/relay/attrs/nn.h
@@ -1055,7 +1055,7 @@ struct SparseConv2DAttrs : public tvm::AttrsNode {
         "'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
         "dimensions respectively.");
     TVM_ATTR_FIELD(kernel_size)
-        .set_default(Array {1, 1})
+        .set_default(Array{1, 1})
         .describe("Kernel size for SparseConv2D, 1x1 or 3x3. ");
   }
 };
diff --git a/src/relay/transforms/convert_sparse_conv2d.cc b/src/relay/transforms/convert_sparse_conv2d.cc
index 9ea873f12a77..3f2c25e988f9 100644
--- a/src/relay/transforms/convert_sparse_conv2d.cc
+++ b/src/relay/transforms/convert_sparse_conv2d.cc
@@ -114,7 +114,7 @@ class Conv2dToSparseConv2dMutator : public ExprRewriter {
     Var weight_indptr(prefix + ".indptr", ws_indptr_type);
     auto attrs = make_object();
     attrs->layout = std::move(layout_);
-    attrs->kernel_size = Array {kernel_size_, kernel_size_};
+    attrs->kernel_size = Array{kernel_size_, kernel_size_};
     return Call(sparse_conv2d_op_, {data, weight_data, weight_indices, weight_indptr},
                 Attrs(attrs));
   }
@@ -263,7 +263,7 @@ class Conv2dToSparseConv2dMutator2 : public ExprRewriter {
                      Constant(weight_indices), Constant(weight_indptr)};
     auto attrs = make_object();
     attrs->layout = layout_;
-    attrs->kernel_size = Array {kernel_size_, kernel_size_};
+    attrs->kernel_size = Array{kernel_size_, kernel_size_};
     return Call(sparse_conv2d_op_, args, Attrs(attrs));
   }

From 1a9597e87e7ae9e402f3d93661633711ae60e6ad Mon Sep 17 00:00:00 2001
From: Tantalus
Date: Tue, 10 Aug 2021 22:33:55 +0800
Subject: [PATCH 19/21] fixup 1x1 tests; new 3x3 tests

---
 .../relay/data_dep_optimization/bsr_conv2d.py | 34 ++++++++++
 python/tvm/relay/op/strategy/x86.py           | 12 ++--
 python/tvm/relay/transform/transform.py       | 20 ++++++
 .../relay/test_sparse_conv2d_convert.py       | 63 +++++++++++++++++++
 4 files changed, 121 insertions(+), 8 deletions(-)

diff --git a/python/tvm/relay/data_dep_optimization/bsr_conv2d.py b/python/tvm/relay/data_dep_optimization/bsr_conv2d.py
index b97fbe44d7cb..20e01da1493e 100644
--- a/python/tvm/relay/data_dep_optimization/bsr_conv2d.py
+++ b/python/tvm/relay/data_dep_optimization/bsr_conv2d.py
@@ -58,3 +58,37 @@ def convert(func, params, blocksize, sparsity_threshold, layout="NHWC", kernel_s
     )
 
     return new_func, params
+
+
+def convert2(func, params, blocksize, sparsity_threshold, layout, kernel_size):
+    """Convert a frozen conv2d func to block sparse
+
+    Parameters
+    ----------
+    func : relay.Expr
+        Expr will be optimized to sparse operation, with params frozen
+    params : Dict[String, tvm.nd.array]
+        Parameters of the Expr (not used in this pass)
+    blocksize : Tuple(int, int)
+        Blocksize for BSR matrix
+    sparsity_threshold : float
+        Minimal sparsity requirement for converting.
+        If weight sparsity is lower than this threshold,
+        the dense operation will be kept.
+    layout : str
+        layout of network
+    kernel_size : int
+        kernel size of the conv2d, for filtering
+
+    Returns
+    -------
+    new_func: relay.Expr
+        Mutated Expr with sparse operations
+
+    params: Dict[String, tvm.nd.array]
+        New params with BSR matrix for mutated Expr (not modified)
+    """
+    new_func = _run_opt_pass(
+        func, relay.transform.Conv2dToSparse2(layout, kernel_size, blocksize, sparsity_threshold)
+    )
+    return new_func, params
diff --git a/python/tvm/relay/op/strategy/x86.py b/python/tvm/relay/op/strategy/x86.py
index 710d545f69f5..e3293bfdf3bd 100644
--- a/python/tvm/relay/op/strategy/x86.py
+++ b/python/tvm/relay/op/strategy/x86.py
@@ -319,9 +319,7 @@ def conv3d_strategy_cpu(attrs, inputs, out_type, target):
     # or packed layouts.
     if layout == "NCDHW":
         strategy.add_implementation(
-            wrap_compute_conv3d(topi.nn.conv3d_ncdhw),
-            naive_schedule,
-            name="conv3d_ncdhw.x86",
+            wrap_compute_conv3d(topi.nn.conv3d_ncdhw), naive_schedule, name="conv3d_ncdhw.x86",
         )
     elif layout == "NDHWC":
         strategy.add_implementation(
@@ -440,9 +438,7 @@ def matmul_strategy_cpu(attrs, inputs, out_type, target):
             "Recommend to use cblas/mkl/mkldnn for better performance."
         )
         strategy.add_implementation(
-            wrap_compute_matmul(topi.nn.matmul),
-            naive_schedule,
-            name="matmul.generic",
+            wrap_compute_matmul(topi.nn.matmul), naive_schedule, name="matmul.generic",
        )
     return strategy
@@ -569,13 +565,13 @@ def sparse_dense_strategy_cpu(attrs, inputs, out_type, target):
 def sparse_conv2d_strategy_cpu(attrs, inputs, out_type, target):
     """sparse conv2d x86 strategy"""
     strategy = _op.OpStrategy()
-    if attrs["kernel_size"] == 1:
+    if attrs["kernel_size"][0] == 1:
         strategy.add_implementation(
             wrap_compute_sparse_conv2d(topi.nn.sparse_conv2d),
             wrap_topi_schedule(topi.generic.schedule_sparse_conv2d),
             name="sparse_conv2d.generic",
         )
-    elif attrs["kernel_size"] == 3:
+    elif attrs["kernel_size"][0] == 3:
         if attrs["layout"] == "NHWC":
             strategy.add_implementation(
                 wrap_compute_sparse_conv2d(topi.x86.spconv2d_3x3_nhwc),
diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py
index 174a0c511633..9a7857a01fe6 100644
--- a/python/tvm/relay/transform/transform.py
+++ b/python/tvm/relay/transform/transform.py
@@ -1116,6 +1116,26 @@ def Conv2dToSparse(weight_name, weight_shape, layout, kernel_size):
     return _ffi_api.Conv2dToSparse(weight_name, weight_shape, layout, kernel_size)
 
 
+def Conv2dToSparse2(layout, kernel_size, blocksize, sparsity_threshold):
+    """
+    Rewrite frozen ```nn.conv2d``` operation to ```nn.sparse_conv2d```
+
+    Parameters
+    ----------
+    layout : str
+        layout of data
+
+    kernel_size : int
+        kernel size of conv2d
+
+    Returns
+    -------
+    ret : tvm.transform.Pass
+        The registered Conv2dToSparse2 pass.
+ """ + return _ffi_api.Conv2dToSparse2(layout, kernel_size, *blocksize, sparsity_threshold) + + def SimplifyFCTranspose(target_weight_name): """ Rewrite ```y = nn.dense(x, transpose(w, [1, 0]))``` to ```y = nn.dense(x, wt)``` diff --git a/tests/python/relay/test_sparse_conv2d_convert.py b/tests/python/relay/test_sparse_conv2d_convert.py index 0af78fc033ac..045462475ee1 100644 --- a/tests/python/relay/test_sparse_conv2d_convert.py +++ b/tests/python/relay/test_sparse_conv2d_convert.py @@ -25,6 +25,7 @@ from tvm.ir import IRModule from tvm import relay from tvm.topi.sparse.utils import random_bsr_matrix +from tvm.relay.build_module import bind_params_by_name def run_func(func, params, x): @@ -100,6 +101,68 @@ def test_bsr_sparse_conv2d_nhwc(): np.testing.assert_allclose(sparse_output, dense_output, atol=1e-5, rtol=1e-5) +def test_bsr_sparse_conv2d_3x3_nchw(): + data = relay.var("data", shape=(1, 64, 32, 32), dtype="float32") + x = relay.nn.relu(data) + w = relay.var("weight", shape=(128, 64, 3, 3), dtype="float32") + y = relay.nn.conv2d( + x, w, channels=128, kernel_size=3, padding=1, data_layout="NCHW", kernel_layout="OIHW" + ) + z = relay.nn.relu(y) + func = relay.Function(relay.analysis.free_vars(z), z) + + params = { + "weight": tvm.nd.array( + np.array(random_bsr_matrix(128, 64 * 9, 16, 1, 0.1, "float32").todense()).reshape( + 128, 64, 3, 3 + ) + ) + } + + x_np = np.random.randn(1, 64, 32, 32).astype("float32") + # dense output + dense_output = run_func(func, params, x_np) + # sparse + func = bind_params_by_name(func, params) + sparse_func, params = relay.data_dep_optimization.bsr_conv2d.convert2( + func, {}, (16, 1), 0.2, "NCHW", 3 + ) + sparse_output = run_func(sparse_func, params, x_np) + np.testing.assert_allclose(sparse_output, dense_output, atol=1e-5, rtol=1e-5) + + +def test_bsr_sparse_conv2d_3x3_nhwc(): + data = relay.var("data", shape=(1, 32, 32, 64), dtype="float32") + x = relay.nn.relu(data) + w = relay.var("weight", shape=(3, 3, 64, 128), dtype="float32") + y = relay.nn.conv2d( + x, w, channels=128, kernel_size=3, padding=1, data_layout="NHWC", kernel_layout="HWIO" + ) + z = relay.nn.relu(y) + func = relay.Function(relay.analysis.free_vars(z), z) + + params = { + "weight": tvm.nd.array( + np.array(random_bsr_matrix(128, 64 * 9, 16, 1, 0.1, "float32").todense()).T.reshape( + 3, 3, 64, 128 + ) + ) + } + + x_np = np.random.randn(1, 32, 32, 64).astype("float32") + # dense output + dense_output = run_func(func, params, x_np) + # sparse + func = bind_params_by_name(func, params) + sparse_func, params = relay.data_dep_optimization.bsr_conv2d.convert2( + func, {}, (16, 1), 0.2, "NHWC", 3 + ) + sparse_output = run_func(sparse_func, params, x_np) + np.testing.assert_allclose(sparse_output, dense_output, atol=1e-5, rtol=1e-5) + + if __name__ == "__main__": test_bsr_sparse_conv2d_nhwc() test_bsr_sparse_conv2d_nchw() + test_bsr_sparse_conv2d_3x3_nhwc() + test_bsr_sparse_conv2d_3x3_nchw() From ceac94fd867c7cdf2bb8455b502cbfdc24a6a140 Mon Sep 17 00:00:00 2001 From: Tantalus Date: Tue, 10 Aug 2021 22:37:19 +0800 Subject: [PATCH 20/21] revert some black format --- python/tvm/relay/op/strategy/x86.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/op/strategy/x86.py b/python/tvm/relay/op/strategy/x86.py index e3293bfdf3bd..1c8d1b478cb1 100644 --- a/python/tvm/relay/op/strategy/x86.py +++ b/python/tvm/relay/op/strategy/x86.py @@ -319,7 +319,9 @@ def conv3d_strategy_cpu(attrs, inputs, out_type, target): # or packed layouts. 
if layout == "NCDHW": strategy.add_implementation( - wrap_compute_conv3d(topi.nn.conv3d_ncdhw), naive_schedule, name="conv3d_ncdhw.x86", + wrap_compute_conv3d(topi.nn.conv3d_ncdhw), + naive_schedule, + name="conv3d_ncdhw.x86", ) elif layout == "NDHWC": strategy.add_implementation( @@ -438,7 +440,9 @@ def matmul_strategy_cpu(attrs, inputs, out_type, target): "Recommend to use cblas/mkl/mkldnn for better performance." ) strategy.add_implementation( - wrap_compute_matmul(topi.nn.matmul), naive_schedule, name="matmul.generic", + wrap_compute_matmul(topi.nn.matmul), + naive_schedule, + name="matmul.generic", ) return strategy From fb732dd8a2c34fe4f6dd52a86f8dfc8dbea2b80c Mon Sep 17 00:00:00 2001 From: Tantalus Date: Wed, 25 Aug 2021 08:29:00 -0400 Subject: [PATCH 21/21] empty commit to trigger ci again
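
For reference, a minimal NumPy sketch of what the spconv2d_3x3_nhwc compute expresses: an
im2col with padding 1 whose columns are laid out as (kh, kw, ci), followed by a GEMM against
the BSR weight matrix of shape (CO, 9*CI). The helper name and the scipy.sparse dependency
are assumptions of this sketch, not code from the patches; the TE kernel itself keeps the
weight in BSR form and never densifies it.

import numpy as np
import scipy.sparse as sp


def spconv2d_3x3_nhwc_ref(data, w_bsr):
    """Dense reference: im2col + GEMM, matching the TE compute's indexing (sketch only)."""
    n, h, w, ci = data.shape                      # NHWC input
    co = w_bsr.shape[0]                           # w_bsr: (CO, 9*CI) BSR matrix
    pad = np.zeros((n, h + 2, w + 2, ci), dtype=data.dtype)
    pad[:, 1:-1, 1:-1, :] = data                  # zero padding of 1, as in Im2Col's bounds check
    cols = np.empty((n * h * w, 9 * ci), dtype=data.dtype)
    for kh in range(3):
        for kw in range(3):
            patch = pad[:, kh:kh + h, kw:kw + w, :].reshape(n * h * w, ci)
            cols[:, (kh * 3 + kw) * ci:(kh * 3 + kw + 1) * ci] = patch
    out = cols @ w_bsr.toarray().T                # (N*H*W, CO)
    return out.reshape(n, h, w, co)


# e.g. w_bsr = sp.bsr_matrix(dense_weight_2d, blocksize=(16, 1)), data a float32 NHWC array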
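
A rough end-to-end use of the new convert2 entry point, mirroring the NHWC test added in
tests/python/relay/test_sparse_conv2d_convert.py (blocksize, threshold and shapes are simply
the values used there); the final relay.build step is the usual follow-up and is shown as an
assumption rather than copied from the test.

import numpy as np
import tvm
from tvm import relay
from tvm.relay.build_module import bind_params_by_name
from tvm.topi.sparse.utils import random_bsr_matrix

# A single NHWC 3x3 conv2d whose weight is block-sparse by construction.
data = relay.var("data", shape=(1, 32, 32, 64), dtype="float32")
weight = relay.var("weight", shape=(3, 3, 64, 128), dtype="float32")
out = relay.nn.conv2d(
    data, weight, channels=128, kernel_size=3, padding=1,
    data_layout="NHWC", kernel_layout="HWIO",
)
func = relay.Function(relay.analysis.free_vars(out), out)
params = {
    "weight": tvm.nd.array(
        np.array(random_bsr_matrix(128, 64 * 9, 16, 1, 0.1, "float32").todense())
        .T.reshape(3, 3, 64, 128)
    )
}

# Freeze the weight into the function, then rewrite conv2d -> sparse_conv2d
# with blocksize (16, 1), sparsity threshold 0.2, NHWC layout, 3x3 kernel.
func = bind_params_by_name(func, params)
sparse_func, new_params = relay.data_dep_optimization.bsr_conv2d.convert2(
    func, {}, (16, 1), 0.2, "NHWC", 3
)

with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(tvm.IRModule.from_expr(sparse_func), target="llvm", params=new_params)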