diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py
index 37ee6b6e929f..6ae86c0786e5 100644
--- a/python/tvm/relay/op/nn/_nn.py
+++ b/python/tvm/relay/op/nn/_nn.py
@@ -78,6 +78,17 @@ def legalize_dense(attrs, inputs, types):
 reg.register_pattern("nn.dense", reg.OpPattern.OUT_ELEMWISE_FUSABLE)
 
 
+@reg.register_alter_op_layout("nn.dense")
+def alter_op_layout_dense(attrs, inputs, tinfos, out_type):
+    """Alter the layout of dense"""
+    return topi.nn.dense_alter_layout(attrs, inputs, tinfos, out_type)
+
+
+# dense_pack
+reg.register_strategy("nn.contrib_dense_pack", strategy.dense_pack_strategy)
+reg.register_pattern("nn.contrib_dense_pack", reg.OpPattern.OUT_ELEMWISE_FUSABLE)
+
+
 # fifo_buffer
 @reg.register_compute("nn.fifo_buffer")
 def compute_fifo_buffer(attrs, inputs, out_type):
@@ -1130,6 +1141,25 @@ def dense_shape_func(attrs, inputs, _):
     return ret
 
 
+@script
+def _dense_pack_shape_func(data_shape, weight_shape):
+    out = output_tensor((data_shape.shape[0],), "int64")
+    for i in const_range(out.shape[0] - 1):
+        out[i] = data_shape[i]
+    out[out.shape[0] - 1] = weight_shape[0] * weight_shape[2]
+
+    return out
+
+
+@reg.register_shape_func("nn.contrib_dense_pack", False)
+def dense_pack_shape_func(attrs, inputs, _):
+    """
+    Shape function for dense_pack op.
+    """
+    ret = [_dense_pack_shape_func(inputs[0], inputs[1])]
+    return ret
+
+
 @script
 def _batch_matmul_shape_func(data_shape, weight_shape):
     out = output_tensor((data_shape.shape[0],), "int64")
diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py
index 562cee5f53bb..0c233a6e3b53 100644
--- a/python/tvm/relay/op/nn/nn.py
+++ b/python/tvm/relay/op/nn/nn.py
@@ -1435,6 +1435,39 @@ def dense(data, weight, units=None, out_dtype=""):
     return _make.dense(data, weight, units, out_dtype)
 
 
+def contrib_dense_pack(data, weight, units=None, out_dtype=""):
+    """Dense operator with packed weight.
+
+    Applies a linear transformation
+
+    .. math::
+
+        `Y = X * W^T`
+
+    Parameters
+    ----------
+    data : tvm.relay.Expr
+        The input data to the operator,
+        of shape `(d_1, d_2, ..., d_n, units_in)`.
+
+    weight : tvm.relay.Expr
+        The transformed weight expressions, 3-D matrix,
+        of shape `(units // pack_weight_tile, units_in, pack_weight_tile)`.
+
+    units : int, optional
+        Number of hidden units of the dense transformation.
+
+    out_dtype : str, optional
+        Specifies the output data type for mixed precision dense.
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The computed result, of shape `(d_1, d_2, ..., d_n, units)`.
+ """ + return _make.contrib_dense_pack(data, weight, units, out_dtype) + + def fifo_buffer(data, buffer, axis): """FIFO buffer to enable computation reuse in CNNs with sliding indow input diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py index 3ad75faf4bc1..f35303895ddc 100644 --- a/python/tvm/relay/op/strategy/generic.py +++ b/python/tvm/relay/op/strategy/generic.py @@ -731,6 +731,19 @@ def dense_strategy(attrs, inputs, out_type, target): return strategy +@override_native_generic_func("dense_pack_strategy") +def dense_pack_strategy(attrs, inputs, out_type, target): + """dense_pack generic strategy""" + logger.warning("dense_pack is not optimized for this platform.") + strategy = _op.OpStrategy() + strategy.add_implementation( + wrap_compute_dense(topi.nn.dense_pack), + wrap_topi_schedule(topi.generic.schedule_dense), + name="dense_pack.generic", + ) + return strategy + + # batch_matmul def wrap_compute_batch_matmul(topi_compute, need_auto_scheduler_layout=False): """wrap batch_matmul topi compute""" diff --git a/python/tvm/relay/op/strategy/x86.py b/python/tvm/relay/op/strategy/x86.py index edfaaeefc5df..f33c45b248d6 100644 --- a/python/tvm/relay/op/strategy/x86.py +++ b/python/tvm/relay/op/strategy/x86.py @@ -364,7 +364,6 @@ def conv1d_strategy_cpu(attrs, inputs, out_type, target): def dense_strategy_cpu(attrs, inputs, out_type, target): """dense x86 strategy""" strategy = _op.OpStrategy() - m, _ = inputs[0].shape same_type = inputs[0].dtype == inputs[1].dtype == out_type.dtype dtype = inputs[0].dtype u8s8s32 = dtype == "uint8" and inputs[1].dtype == "int8" and out_type.dtype == "int32" @@ -372,6 +371,13 @@ def dense_strategy_cpu(attrs, inputs, out_type, target): wrap_compute_dense(topi.x86.dense_nopack), wrap_topi_schedule(topi.x86.schedule_dense_nopack), name="dense_nopack.x86", + plevel=5, + ) + + strategy.add_implementation( + wrap_compute_dense(topi.x86.dense_pack), + wrap_topi_schedule(topi.x86.schedule_dense_pack), + name="dense_pack.x86", plevel=10, ) @@ -407,14 +413,18 @@ def dense_strategy_cpu(attrs, inputs, out_type, target): name="dense_mkldnn.x86", plevel=15, ) - with SpecializedCondition(m >= 16): - # this implementation may not be well-optimized, so use plevel=5 for now. - strategy.add_implementation( - wrap_compute_dense(topi.x86.dense_pack), - wrap_topi_schedule(topi.x86.schedule_dense_pack), - name="dense_pack.x86", - plevel=5, - ) + return strategy + + +@dense_pack_strategy.register("cpu") +def dense_pack_strategy_cpu(attrs, inputs, out_type, target): + """dense_pack x86 strategy""" + strategy = _op.OpStrategy() + strategy.add_implementation( + wrap_compute_dense(topi.x86.dense_pack), + wrap_topi_schedule(topi.x86.schedule_dense_pack), + name="dense_pack.x86", + ) return strategy diff --git a/python/tvm/topi/nn/dense.py b/python/tvm/topi/nn/dense.py index bb6ea90c3fcd..e8ec476b86a5 100644 --- a/python/tvm/topi/nn/dense.py +++ b/python/tvm/topi/nn/dense.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +# pylint: disable=invalid-name,unused-argument """TVM operator fully connected compute.""" import tvm from tvm import te, auto_scheduler @@ -104,3 +105,72 @@ def dense_legalize(attrs, inputs, types): # not to change by default # pylint: disable=unused-argument return None + + +def dense_pack(data, weight, bias=None, out_dtype=None): + """The default implementation of dense_pack in topi. 
+
+    Parameters
+    ----------
+    data : tvm.te.Tensor
+        2-D with shape [batch, in_dim]
+
+    weight : tvm.te.Tensor
+        3-D with shape [out_dim // pack_weight_tile, in_dim, pack_weight_tile]
+
+    bias : Optional[tvm.te.Tensor]
+        1-D with shape [out_dim]
+
+    out_dtype : Optional[str]
+        The output type. This is used for mixed precision.
+
+    Returns
+    -------
+    output : tvm.te.Tensor
+        2-D with shape [batch, out_dim]
+    """
+    if out_dtype is None:
+        out_dtype = data.dtype
+    M, K = get_const_tuple(data.shape)  # batch, in_dim
+    N, _, packw_bn = get_const_tuple(weight.shape)  # out_dim
+    N = N * packw_bn
+
+    idxdiv = tvm.tir.indexdiv
+    idxmod = tvm.tir.indexmod
+    k = te.reduce_axis((0, K), name="k")
+    C = te.compute(
+        (M, N),
+        lambda y, x: te.sum(
+            data[y, k].astype(out_dtype)
+            * weight[idxdiv(x, packw_bn), k, idxmod(x, packw_bn)].astype(out_dtype),
+            axis=k,
+        ),
+        name="T_dense_pack",
+        tag="dense_pack",
+    )
+    if bias is not None:
+        C = te.compute((M, N), lambda i, j: C[i, j] + bias[j].astype(out_dtype), tag=tag.BROADCAST)
+    return C
+
+
+@tvm.target.generic_func
+def dense_alter_layout(attrs, inputs, tinfos, out_type):
+    """Change dense layout.
+
+    Parameters
+    ----------
+    attrs : tvm.ir.Attrs
+        Attributes of current dense
+    inputs : tvm.relay.Expr
+        Grouped input symbols
+    tinfos : list
+        Input shape and dtype
+    out_type : type
+        The output type
+
+    Note
+    ----
+    Unlike other TOPI functions, this function operates on both graph level and operator level.
+    """
+    # not to change by default
+    return None
diff --git a/python/tvm/topi/x86/__init__.py b/python/tvm/topi/x86/__init__.py
index 154511010a1c..bb6a7cdd4122 100644
--- a/python/tvm/topi/x86/__init__.py
+++ b/python/tvm/topi/x86/__init__.py
@@ -39,4 +39,5 @@
 from .conv3d_transpose import *
 from .sparse import *
 from .conv2d_alter_op import *
+from .dense_alter_op import *
 from .scatter import *
diff --git a/python/tvm/topi/x86/dense.py b/python/tvm/topi/x86/dense.py
index 15d7a1a310d6..6011f01c2cb0 100644
--- a/python/tvm/topi/x86/dense.py
+++ b/python/tvm/topi/x86/dense.py
@@ -15,6 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 # pylint: disable=invalid-name,too-many-locals,unused-variable
+# pylint: disable=no-value-for-parameter
 """x86 dense operators"""
 from __future__ import absolute_import as _abs
 import tvm
@@ -26,11 +27,12 @@
 from tvm.contrib import mkldnn
 
 from .utils import get_fp32_len
+from .injective import schedule_injective_from_existing
 from .. import generic, tag
 from ..utils import traverse_inline, get_const_tuple
 
 
-def _schedule_dense_pack_template(cfg, s, C):
+def _schedule_dense_pack_template(cfg, s, C, O):
     A, packedB = s[C].op.input_tensors
 
     CC = s.cache_write(C, "global")
@@ -39,9 +41,10 @@ def _schedule_dense_pack_template(cfg, s, C):
     yt, yo, yi = cfg["tile_y"].apply(s, C, y)
     xt, xo, xi = cfg["tile_x"].apply(s, C, x)
-    s[C].reorder(yt, xt, yo, xo, yi, xi)
-    xyt = s[C].fuse(yt, xt)
-    s[C].parallel(xyt)
+    s[C].reorder(xt, yt, yo, xo, yi, xi)
+    xyt = s[C].fuse(xt, yt)
+    if C == O:
+        s[C].parallel(xyt)
     xyo = s[C].fuse(yo, xo)
     s[C].unroll(yi)
     s[C].vectorize(xi)
@@ -51,12 +54,27 @@ def _schedule_dense_pack_template(cfg, s, C):
     ko, ki = cfg["tile_k"].apply(s, CC, k)
     s[CC].reorder(ko, ki, y, x)
     s[CC].vectorize(x)
-    s[CC].unroll(y)
-    s[CC].unroll(ki)
 
-    z, y, x = s[packedB].op.axis
-    s[packedB].reorder(z, x, y)
-    s[packedB].parallel(z)
+    tile_inner = cfg["tile_inner"].size[-1]
+    if tile_inner > 1:
+        yo, yi = s[CC].split(y, tile_inner)
+        s[CC].reorder(ko, yo, ki, yi, x)
+        s[CC].unroll(yo)
+        s[CC].unroll(ki)
+        s[CC].unroll(yi)
+    else:
+        s[CC].unroll(ki)
+        s[CC].unroll(y)
+
+    if C != O:
+        y, x = s[O].op.axis
+        yt, yo, yi = cfg["tile_y"].apply(s, O, y)
+        xt, xo, xi = cfg["tile_x"].apply(s, O, x)
+        s[O].reorder(xt, yt, yo, xo, yi, xi)
+        xyt = s[O].fuse(xt, yt)
+        s[C].compute_at(s[O], xyt)
+        s[O].vectorize(xi)
+        s[O].parallel(xyt)
     return s
 
 
@@ -83,11 +101,11 @@ def _schedule_dense_nopack_template(cfg, s, C):
 
 def _default_dense_pack_config(cfg, M, N, K):
     # Generate default schedule for dynamic shape.
-    if isinstance(M, tvm.tir.Var):
+    if isinstance(M, (tvm.tir.Var, tvm.tir.Any)):
         M = 16
-    if isinstance(N, tvm.tir.Var):
+    if isinstance(N, (tvm.tir.Var, tvm.tir.Any)):
         N = 16
-    if isinstance(K, tvm.tir.Var):
+    if isinstance(K, (tvm.tir.Var, tvm.tir.Any)):
         K = 16
 
     vec_width = get_fp32_len()
@@ -116,15 +134,16 @@ def _default_dense_pack_config(cfg, M, N, K):
     cfg["tile_y"] = SplitEntity([MM // tiley_oi, tiley_oi, tiley_ii])
     cfg["tile_x"] = SplitEntity([NN // tilex_oi, tilex_oi, tilex_ii])
     cfg["tile_k"] = SplitEntity([K, 1])
+    cfg["tile_inner"] = SplitEntity([M // tiley_ii, tiley_ii])
 
 
 def _default_dense_nopack_config(cfg, M, N, K):
     # Generate default schedule for dynamic shape.
-    if isinstance(M, tvm.tir.Var):
+    if isinstance(M, (tvm.tir.Var, tvm.tir.Any)):
         M = 16
-    if isinstance(N, tvm.tir.Var):
+    if isinstance(N, (tvm.tir.Var, tvm.tir.Any)):
         N = 16
-    if isinstance(K, tvm.tir.Var):
+    if isinstance(K, (tvm.tir.Var, tvm.tir.Any)):
         K = 16
 
     vec_width = get_fp32_len()
@@ -146,9 +165,15 @@ def dense_nopack(cfg, data, weight, bias=None, out_dtype=None):
     M, K = get_const_tuple(data.shape)
     N, _ = get_const_tuple(weight.shape)
     # create tuning space
-    cfg.define_split("tile_y", 32 if isinstance(M, tvm.tir.Var) else M, num_outputs=2)
-    cfg.define_split("tile_x", 32 if isinstance(N, tvm.tir.Var) else N, num_outputs=2)
-    cfg.define_split("tile_k", 32 if isinstance(K, tvm.tir.Var) else K, num_outputs=2)
+    cfg.define_split(
+        "tile_y", 32 if isinstance(M, (tvm.tir.Var, tvm.tir.Any)) else M, num_outputs=2
+    )
+    cfg.define_split(
+        "tile_x", 32 if isinstance(N, (tvm.tir.Var, tvm.tir.Any)) else N, num_outputs=2
+    )
+    cfg.define_split(
+        "tile_k", 32 if isinstance(K, (tvm.tir.Var, tvm.tir.Any)) else K, num_outputs=2
+    )
     if cfg.is_fallback:
         _default_dense_nopack_config(cfg, M, N, K)
@@ -184,23 +209,46 @@ def _callback(op):
 
 @autotvm.register_topi_compute("dense_pack.x86")
 def dense_pack(cfg, data, weight, bias=None, out_dtype=None):
-    """Compute dense with packing"""
+    """Compute dense with transformed weight."""
     if out_dtype is None:
         out_dtype = data.dtype
     M, K = get_const_tuple(data.shape)  # batch, in_dim
-    N, _ = get_const_tuple(weight.shape)  # out_dim
+    if len(weight.shape) == 3:
+        N, _, packw_bn = get_const_tuple(weight.shape)  # out_dim
+        N = N * packw_bn
+    else:
+        N, _ = get_const_tuple(weight.shape)  # out_dim
     # create tuning space
-    cfg.define_split("tile_y", M, num_outputs=3)
-    cfg.define_split("tile_x", N, num_outputs=3)
-    cfg.define_split("tile_k", K, num_outputs=2)
+    cfg.define_split(
+        "tile_y", 32 if isinstance(M, (tvm.tir.Var, tvm.tir.Any)) else M, num_outputs=3
+    )
+    cfg.define_split(
+        "tile_x", 32 if isinstance(N, (tvm.tir.Var, tvm.tir.Any)) else N, num_outputs=3
+    )
+    cfg.define_split(
+        "tile_k", 32 if isinstance(K, (tvm.tir.Var, tvm.tir.Any)) else K, num_outputs=2
+    )
+    cfg.define_split(
+        "tile_inner",
+        32 if isinstance(M, (tvm.tir.Var, tvm.tir.Any)) else M,
+        num_outputs=2,
+        filter=lambda y: y.size[-1] <= 16,
+    )
    if cfg.is_fallback:
         _default_dense_pack_config(cfg, M, N, K)
 
-    packw_bn = cfg["tile_x"].size[-1]
-    packw_shape = (N // packw_bn, K, packw_bn)
-    packw = te.compute(
-        packw_shape, lambda z, y, x: weight[z * packw_bn + x, y], name="packed_weight"
-    )
+    if len(weight.shape) == 2:
+        packw_bn = cfg["tile_x"].size[-1]
+        packw_shape = (N // packw_bn, K, packw_bn)
+        if autotvm.GLOBAL_SCOPE.in_tuning:
+            # Directly use modified data layout placeholder.
+            packw = tvm.te.placeholder(packw_shape, weight.dtype, name="packed_weight")
+        else:
+            packw = te.compute(
+                packw_shape, lambda z, y, x: weight[z * packw_bn + x, y], name="packed_weight"
+            )
+    else:
+        packw = weight
 
     idxdiv = tvm.tir.indexdiv
     idxmod = tvm.tir.indexmod
@@ -226,7 +274,7 @@ def schedule_dense_pack(cfg, outs):
 
     def _callback(op):
         if "dense_pack" in op.tag:
-            _schedule_dense_pack_template(cfg, s, op.output(0))
+            _schedule_dense_pack_template(cfg, s, op.output(0), outs[0])
 
     traverse_inline(s, outs[0].op, _callback)
     return s
@@ -276,7 +324,19 @@ def dense_mkl(cfg, data, weight, bias=None, out_dtype=None):
 @autotvm.register_topi_schedule("dense_mkl.x86")
 def schedule_dense_mkl(_, outs):
     """Create schedule for dense_mkl"""
-    return generic.schedule_extern(outs)
+    # return generic.schedule_extern(outs)
+    s = te.create_schedule([x.op for x in outs])
+    te.schedule.AutoInlineInjective(s)
+
+    def _callback(op):
+        if "broadcast" in op.tag or "injective" in op.tag or "elemwise" in op.tag:
+            schedule_injective_from_existing(s, op.output(0))
+
+    # traverse_inline(s, outs[0].op, _callback)
+    for out in outs:
+        if "dense" not in out.op.name:
+            schedule_injective_from_existing(s, out)
+    return s
 
 
 @autotvm.register_topi_compute("dense_mkldnn.x86")
diff --git a/python/tvm/topi/x86/dense_alter_op.py b/python/tvm/topi/x86/dense_alter_op.py
new file mode 100644
index 000000000000..5e15c8bf5368
--- /dev/null
+++ b/python/tvm/topi/x86/dense_alter_op.py
@@ -0,0 +1,68 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name,unused-variable,unused-argument,no-member
+"""Dense alter op functions for x86"""
+
+import tvm
+from tvm import te
+from tvm import relay
+from tvm import autotvm
+from .dense import _default_dense_pack_config
+from ..utils import get_const_tuple
+from ..nn import dense_alter_layout
+
+
+@dense_alter_layout.register(["cpu", "arm_cpu"])
+def _alter_dense_layout(attrs, inputs, tinfos, out_type):
+    target = tvm.target.Target.current(allow_none=False)
+    dispatch_ctx = autotvm.task.DispatchContext.current
+    data_tensor, weight_tensor = tinfos
+    out_dtype = out_type.dtype
+    M, K = get_const_tuple(data_tensor.shape)
+    N, _ = get_const_tuple(weight_tensor.shape)
+
+    impl, outs = relay.backend.compile_engine.select_implementation(
+        relay.op.get("nn.dense"), attrs, tinfos, out_type, target
+    )
+    workload = autotvm.task.get_workload(outs)
+    if workload:
+        cfg = dispatch_ctx.query(target, workload)
+        topi_impl = workload[0]
+        if topi_impl == "dense_pack.x86":
+            if cfg.is_fallback:
+                _default_dense_pack_config(cfg, M, N, K)
+            packw_bn = cfg["tile_x"].size[-1]
+            weight_layout = "NK%dn" % packw_bn
+            new_weight = te.placeholder(
+                (N // packw_bn, K, packw_bn),
+                dtype=weight_tensor.dtype,
+            )
+            # Relay dense doesn't have bias.
+            new_workload = autotvm.task.args_to_workload(
+                [
+                    data_tensor,
+                    new_weight,
+                    None,
+                    out_dtype,
+                ],
+                topi_impl,
+            )
+            dispatch_ctx.update(target, new_workload, cfg)
+            weight_transform = relay.layout_transform(inputs[1], "NK", weight_layout)
+            return relay.nn.contrib_dense_pack(inputs[0], weight_transform, None, out_dtype)
+
+    return None
diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc
index 8ace82be9ff8..3e3d94c614c3 100644
--- a/src/relay/op/nn/nn.cc
+++ b/src/relay/op/nn/nn.cc
@@ -186,6 +186,33 @@ RELAY_REGISTER_OP("nn.dense")
     .set_support_level(1)
     .add_type_rel("Dense", DenseRel<DenseAttrs>);
 
+// relay.nn.contrib_dense_pack
+// Positional relay function to create dense_pack operator used by frontend FFI.
+Expr MakeDensePack(Expr data, Expr weight, IndexExpr units, DataType out_dtype) {
+  auto attrs = make_object<DenseAttrs>();
+  attrs->units = units;
+  attrs->out_dtype = out_dtype;
+  static const Op& op = Op::Get("nn.contrib_dense_pack");
+  return Call(op, {data, weight}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_dense_pack").set_body_typed(MakeDensePack);
+
+RELAY_REGISTER_OP("nn.contrib_dense_pack")
+    .describe(R"code(Applies a linear transformation: :math:`Y = XW^T`.
+
+- **data**: `(x1, x2, ..., xn, input_dim)`
+- **weight**: `(units // pack_weight_tile, input_dim, pack_weight_tile)`
+- **out**: `(x1, x2, ..., xn, units)`.
+
+)code" TVM_ADD_FILELINE)
+    .set_attrs_type<DenseAttrs>()
+    .set_num_inputs(2)
+    .add_argument("data", "nD Tensor", "Input data.")
+    .add_argument("weight", "3D Tensor", "Packed weight matrix.")
+    .set_support_level(10)
+    .add_type_rel("DensePack", DensePackRel<DenseAttrs>);
+
 // relay.leaky_relu
 TVM_REGISTER_NODE_TYPE(LeakyReluAttrs);
diff --git a/src/relay/op/nn/nn.h b/src/relay/op/nn/nn.h
index 9b9cff2dba81..c00e2e02b369 100644
--- a/src/relay/op/nn/nn.h
+++ b/src/relay/op/nn/nn.h
@@ -31,6 +31,8 @@
 
 #include <utility>
 
+#include "../op_common.h"
+
 namespace tvm {
 namespace relay {
@@ -88,6 +90,29 @@ bool DenseRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   return true;
 }
 
+template <typename AttrType>
+bool DensePackRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+                  const TypeReporter& reporter) {
+  ICHECK_EQ(types.size(), 3);
+  const auto* data = types[0].as<TensorTypeNode>();
+  const auto* weight = types[1].as<TensorTypeNode>();
+  if (data == nullptr || weight == nullptr) return false;
+
+  const AttrType* param = attrs.as<AttrType>();
+  ICHECK(param != nullptr);
+
+  Array<tvm::PrimExpr> oshape = data->shape;
+  oshape.Set((oshape.size() - 1), weight->shape[0] * weight->shape[2]);
+
+  DataType out_dtype = param->out_dtype;
+  if (out_dtype.bits() == 0) {
+    out_dtype = data->dtype;
+  }
+  // assign output type
+  reporter->Assign(types[2], TensorType(oshape, out_dtype));
+  return true;
+}
+
 }  // namespace relay
 }  // namespace tvm
 #endif  // TVM_RELAY_OP_NN_NN_H_
diff --git a/tests/python/relay/test_autotvm_task_extraction.py b/tests/python/relay/test_autotvm_task_extraction.py
index da71ac37f695..d6bfd8d0ec11 100644
--- a/tests/python/relay/test_autotvm_task_extraction.py
+++ b/tests/python/relay/test_autotvm_task_extraction.py
@@ -60,9 +60,9 @@ def test_task_extraction():
     tasks = autotvm.task.extract_from_program(
         mod["main"], target=target, params=params, ops=(dense,)
     )
-    assert len(tasks) == 1
+    assert len(tasks) == 2
     tasks = autotvm.task.extract_from_program(mod, target=target, params=params, ops=(dense,))
-    assert len(tasks) == 1
+    assert len(tasks) == 2
 
     mod, params, _ = get_network("resnet-18", batch_size=1)
     mod_list.append(mod)
     tasks = autotvm.task.extract_from_program(
mod["main"], target=target, params=params, ops=(conv2d, dense) ) - assert len(tasks) == 13 + assert len(tasks) == 14 tasks = autotvm.task.extract_from_program( mod, target=target, params=params, ops=(conv2d, dense) ) - assert len(tasks) == 13 + assert len(tasks) == 14 tasks = autotvm.task.extract_from_program(mod, target=target, params=params) - assert len(tasks) == 13 + assert len(tasks) == 14 mod, params, _ = get_network("resnet3d-18", batch_size=1) tasks = autotvm.task.extract_from_program(mod, target=target, params=params, ops=(conv3d,)) @@ -88,7 +88,7 @@ def test_task_extraction(): tasks = autotvm.task.extract_from_program( mod, target=target, params=params, ops=(conv2d, dense) ) - assert len(tasks) == 20 + assert len(tasks) == 21 mod, params, _ = get_network("dcgan", batch_size=1) tasks = autotvm.task.extract_from_program( diff --git a/tests/python/relay/test_pass_alter_op_layout.py b/tests/python/relay/test_pass_alter_op_layout.py index 58c279d750ec..41186884bdb2 100644 --- a/tests/python/relay/test_pass_alter_op_layout.py +++ b/tests/python/relay/test_pass_alter_op_layout.py @@ -18,7 +18,7 @@ import pytest import tvm -from tvm import relay +from tvm import relay, topi from tvm.relay import transform, analysis from tvm.relay.testing.temp_op_attr import TempOpAttr from tvm.relay.testing import run_infer_type @@ -1248,6 +1248,34 @@ def expected(): assert tvm.ir.structural_equal(a, b, map_free_vars=True), "Actual = \n" + str(a) +def test_alter_op_dense(): + def before(): + x = relay.var("x", shape=(32, 64)) + weight = relay.var("weight", shape=(48, 64)) + y = relay.nn.dense(x, weight) + y = relay.Function(analysis.free_vars(y), y) + return y + + def expected(): + x = relay.var("x", shape=(32, 64)) + weight = relay.var("weight", shape=(48, 64)) + target_layout = "NK16n" + weight_transform = relay.layout_transform(weight, "NK", target_layout) + y = relay.nn.contrib_dense_pack(x, weight_transform, units=None, out_dtype="float32") + y = relay.Function(analysis.free_vars(y), y) + return y + + for target, _ in tvm.testing.enabled_targets(): + with tvm.target.Target(target): + with TempOpAttr( + "nn.dense", "FTVMAlterOpLayout", topi.x86.dense_alter_op._alter_dense_layout + ): + a = before() + a = run_opt_pass(a, transform.AlterOpLayout()) + b = run_opt_pass(expected(), transform.InferType()) + assert tvm.ir.structural_equal(a, b) + + if __name__ == "__main__": test_alter_op() test_alter_return_none() @@ -1269,3 +1297,4 @@ def expected(): test_alter_layout_nhwc_arm() test_alter_layout_nhwc_int8_aarch64() test_alter_op_with_global_var() + test_alter_op_dense() diff --git a/tutorials/micro/micro_tflite.py b/tutorials/micro/micro_tflite.py index c28918380265..c979216d0c6b 100644 --- a/tutorials/micro/micro_tflite.py +++ b/tutorials/micro/micro_tflite.py @@ -195,7 +195,7 @@ # Now, compile the model for the target: with tvm.transform.PassContext( - opt_level=3, config={"tir.disable_vectorize": True}, disabled_pass=["FuseOps"] + opt_level=3, config={"tir.disable_vectorize": True}, disabled_pass=["FuseOps", "AlterOpLayout"] ): graph, c_mod, c_params = relay.build(mod, target=TARGET, params=params)