From a2bbe44101259a5198579dc07b19fb7990526cbc Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Thu, 29 Nov 2018 17:11:15 -0800 Subject: [PATCH] [Relay] Alter Op Layout (#2150) * [RELAY] Finish alter op pass * [RELAY] AlterOpLayout Pass * fix broadcast operators * fix broadcast operators * fix broadcast operators * Support concatenate * address comments * address comments * add comments * rebase --- 3rdparty/HalideIR | 2 +- include/tvm/relay/attrs/nn.h | 5 + include/tvm/relay/attrs/transform.h | 13 + include/tvm/relay/expr.h | 2 +- include/tvm/relay/op_attr_types.h | 15 + include/tvm/relay/pass.h | 16 + python/tvm/__init__.py | 1 + python/tvm/attrs.py | 40 +++ python/tvm/relay/base.py | 14 + python/tvm/relay/build_module.py | 8 + python/tvm/relay/ir_pass.py | 36 ++ python/tvm/relay/op/__init__.py | 4 +- python/tvm/relay/op/_tensor.py | 9 - python/tvm/relay/op/_transform.py | 18 +- python/tvm/relay/op/op.py | 22 +- python/tvm/relay/op/op_attrs.py | 14 + python/tvm/relay/op/transform.py | 22 ++ src/lang/attrs.cc | 6 + src/relay/op/layout.h | 23 +- src/relay/op/nn/convolution.cc | 31 +- src/relay/op/nn/nn.cc | 19 +- src/relay/op/nn/pad.cc | 175 +++++----- src/relay/op/nn/pooling.cc | 39 ++- src/relay/op/op_common.h | 57 +++- src/relay/op/tensor/binary.cc | 40 +-- src/relay/op/tensor/transform.cc | 146 +++++++- src/relay/op/tensor/unary.cc | 83 ++--- src/relay/pass/alter_op_layout.cc | 312 +++++++++++++++++ src/relay/pass/alter_op_layout.h | 119 +++++++ src/relay/pass/canonicalize_ops.cc | 46 +++ src/relay/pass/fold_scale_axis.cc | 4 +- src/relay/pass/forward_rewrite.cc | 55 ++- src/relay/pass/pattern_util.h | 2 +- .../python/relay/test_pass_alter_op_layout.py | 316 ++++++++++++++++++ topi/include/topi/nn.h | 1 + 35 files changed, 1498 insertions(+), 217 deletions(-) create mode 100644 python/tvm/attrs.py create mode 100644 python/tvm/relay/op/op_attrs.py create mode 100644 src/relay/pass/alter_op_layout.cc create mode 100644 src/relay/pass/alter_op_layout.h create 
mode 100644 src/relay/pass/canonicalize_ops.cc create mode 100644 tests/python/relay/test_pass_alter_op_layout.py diff --git a/3rdparty/HalideIR b/3rdparty/HalideIR index e4a4c02764d37..a08e26e5a97f4 160000 --- a/3rdparty/HalideIR +++ b/3rdparty/HalideIR @@ -1 +1 @@ -Subproject commit e4a4c02764d37c9c3db0d64c4996651a3ef9513c +Subproject commit a08e26e5a97f4ef4d566a42f6c78704b3f9c7b8a diff --git a/include/tvm/relay/attrs/nn.h b/include/tvm/relay/attrs/nn.h index 817ee04bd844d..724749368aa96 100644 --- a/include/tvm/relay/attrs/nn.h +++ b/include/tvm/relay/attrs/nn.h @@ -105,6 +105,7 @@ struct Conv2DTransposeAttrs : public tvm::AttrsNode { int groups; std::string data_layout; std::string weight_layout; + std::string out_layout; DataType out_dtype; TVM_DECLARE_ATTRS(Conv2DTransposeAttrs, "relay.attrs.Conv2DTransposeAttrs") { @@ -139,6 +140,10 @@ struct Conv2DTransposeAttrs : public tvm::AttrsNode { .describe("Dimension ordering of data and weight. Can be 'OIHW', 'OIHW16o16i', etc." "'O', 'I', 'H', 'W' stands for num_filter, input_channel, height, and width" "dimensions respectively."); + TVM_ATTR_FIELD(out_layout).set_default("") + .describe("Dimension ordering of output. Can be 'NCHW', 'NHWC', etc." + "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" + "dimensions respectively. 
Default to be same as input layout."); TVM_ATTR_FIELD(out_dtype) .set_default(NullValue()) .describe("Output data type, set to explicit type under mixed precision setting"); diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h index 3e56106df0c2d..7e614a8cafd44 100644 --- a/include/tvm/relay/attrs/transform.h +++ b/include/tvm/relay/attrs/transform.h @@ -164,6 +164,19 @@ struct ClipAttrs : public tvm::AttrsNode { } }; + +struct LayoutTransformAttrs : public tvm::AttrsNode { + std::string src_layout; + std::string dst_layout; + + TVM_DECLARE_ATTRS(LayoutTransformAttrs, "relay.attrs.LayoutTransformAttrs") { + TVM_ATTR_FIELD(src_layout) + .describe("The source layout of the tensor. (e.g. NCHW)"); + TVM_ATTR_FIELD(dst_layout) + .describe("The destination layout of the tensor. (e.g. NCHW16c)"); + } +}; + } // namespace relay } // namespace tvm #endif // TVM_RELAY_ATTRS_TRANSFORM_H_ diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index 469b73a1df100..37c91ffe4ed22 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -459,7 +459,7 @@ inline const TTypeNode* ExprNode::type_as() const { static_assert(std::is_base_of::value, "TType must be a special case of type"); CHECK(checked_type_.defined()) - << "Type inference for this Expr has not completed"; + << "Type inference for this Expr has not completed. Try to call infer_type pass."; const TTypeNode* node = checked_type_.as(); CHECK(node != nullptr) << "Expected type to be " << TTypeNode::_type_key diff --git a/include/tvm/relay/op_attr_types.h b/include/tvm/relay/op_attr_types.h index 3d9fa56855c36..1f37e9947bb8b 100644 --- a/include/tvm/relay/op_attr_types.h +++ b/include/tvm/relay/op_attr_types.h @@ -86,6 +86,21 @@ using FTVMSchedule = runtime::TypedPackedFunc< const Array& outs, const Target& target)>; +/*! + * \brief Alternate the layout of operators or replace the + * operator with other expressions. 
This function will be invoked + * in AlterOpLayout pass. + * \param attrs The attribute of the original node. + * \param inputs The input symbols of the original node. + * \param tinfos An array of placeholders, use for getting the inferred shape + * and dtype of the inputs. + * \return new_expr The modified expression. + */ +using FTVMAlterOpLayout = runtime::TypedPackedFunc< + Expr(const Attrs& attrs, + const Array& args, + const Array& tinfos)>; + /*! * \brief Forward rewriting rule for a specific op. * diff --git a/include/tvm/relay/pass.h b/include/tvm/relay/pass.h index 58e160eb4ac98..8fff7016a827b 100644 --- a/include/tvm/relay/pass.h +++ b/include/tvm/relay/pass.h @@ -8,6 +8,7 @@ #include #include +#include #include namespace tvm { @@ -173,6 +174,21 @@ Expr ForwardRewrite(const Expr& expr, std::function fcontext = nullptr, std::function fmulti_ref_trigger = nullptr); +/*! + * \brief Apply rewrite rules to rewrite the expr in post DFS order. + * \param expr The expression. + * \param rewrite_func The rewrite func that will apply to all operators. + * \param fcontext Additional callback to provide context argument for each call node. + * \param fmulti_ref_trigger Transformation function to be called when + * an Expr consumed by multiple callers. + * \return The rewritten expression. + */ +Expr ForwardRewrite(const Expr& expr, + const FForwardRewrite& rewrite_func, + std::function fcontext = nullptr, + std::function fmulti_ref_trigger = nullptr); + + /*! \brief A hashing structure in the style of std::hash. */ struct StructuralHash { /*! \brief Hash a Relay type. diff --git a/python/tvm/__init__.py b/python/tvm/__init__.py index e202c5adb967d..67dd54d1db4d7 100644 --- a/python/tvm/__init__.py +++ b/python/tvm/__init__.py @@ -13,6 +13,7 @@ from . import schedule from . import module from . import node +from . import attrs from . import ir_builder from . import target from . 
import generic diff --git a/python/tvm/attrs.py b/python/tvm/attrs.py new file mode 100644 index 0000000000000..529dbcc14c138 --- /dev/null +++ b/python/tvm/attrs.py @@ -0,0 +1,40 @@ +""" TVM Attribute module, which is mainly used for defining attributes of operators""" +from ._ffi.node import NodeBase, register_node as _register_tvm_node +from ._ffi.function import _init_api +from . import _api_internal + + +@_register_tvm_node +class Attrs(NodeBase): + """Attribute node, which is mainly use for defining attributes of relay operators. + + Used by function registered in python side, such as compute, schedule and alter_layout. + Attrs is passed as the first argument to these functions. + """ + def list_field_info(self): + """ Get fields information + + Returns + ------- + infos: list of AttrFieldInfo + List of field information + """ + return _api_internal._AttrsListFieldInfo(self) + + def keys(self): + """Get list of names in the attribute. + + Returns + ------- + keys : list of str + List of keys + """ + fields = self.list_field_info() + for field in fields: + yield field.name + + def __getitem__(self, item): + return self.__getattr__(item) + + +_init_api("tvm.attrs") diff --git a/python/tvm/relay/base.py b/python/tvm/relay/base.py index 83aa4ec2cdd06..f1105fe4f0d9d 100644 --- a/python/tvm/relay/base.py +++ b/python/tvm/relay/base.py @@ -21,6 +21,20 @@ def register_relay_node(type_key=None): return _register_tvm_node(type_key) +def register_relay_attr_node(type_key=None): + """register relay attribute node + + Parameters + ---------- + type_key : str or cls + The type key of the node + """ + if not isinstance(type_key, str): + return _register_tvm_node( + "relay.attrs." 
+ type_key.__name__)(type_key) + return _register_tvm_node(type_key) + + class RelayNode(NodeBase): """Base class of all relay node.""" def astext(self, show_meta_data=True, annotate=None): diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py index 863ca063137f6..2a2cd9f82ecb3 100644 --- a/python/tvm/relay/build_module.py +++ b/python/tvm/relay/build_module.py @@ -17,6 +17,7 @@ "FoldConstant": 2, "CombineParallelConv2D": 3, "FoldScaleAxis": 3, + "AlterOpLayout": 3, } class BuildConfig(object): @@ -157,6 +158,13 @@ def optimize(func, params=None): if cfg.pass_enabled("FoldConstant"): func = ir_pass.fold_constant(func) + + if cfg.pass_enabled("AlterOpLayout"): + func = ir_pass.infer_type(func) + func = ir_pass.canonicalize_ops(func) + func = ir_pass.infer_type(func) + func = ir_pass.alter_op_layout(func) + return func diff --git a/python/tvm/relay/ir_pass.py b/python/tvm/relay/ir_pass.py index 6297e366070f7..53fa59cd053da 100644 --- a/python/tvm/relay/ir_pass.py +++ b/python/tvm/relay/ir_pass.py @@ -191,6 +191,23 @@ def simplify_inference(expr): return _ir_pass.simplify_inference(expr) +def canonicalize_ops(expr): + """ Canonicalize special operators to basic operators. + This can simplify latter analysis. (e.g. Expand bias_add to expand_dims and broadcast_add.) + + Parameters + ---------- + e: tvm.relay.Expr + The input Expression + + Returns + ------- + result: tvm.relay.Expr + An expression without bias_add + """ + return _ir_pass.canonicalize_ops(expr) + + def dead_code_elimination(expr): """ Remove expressions which does not effect the program result (dead code). @@ -321,3 +338,22 @@ def combine_parallel_conv2d(expr): Transformed expression """ return _ir_pass.CombineParallelConv2D(expr) + + +def alter_op_layout(expr): + """Alternate the layouts of operators or replace primitive operators with + other expressions. + This pass can be used for computing convolution in custom layouts or + other general weight pre-transformation. 
+ + Parameters + ---------- + expr : tvm.relay.Expr + The input expression. + + Returns + ------- + transformed_expr : tvm.relay.Expr + Transformed expression with alternated layout. + """ + return _ir_pass.AlterOpLayout(expr) diff --git a/python/tvm/relay/op/__init__.py b/python/tvm/relay/op/__init__.py index b32db4c23f3e0..4a6dfd9f73355 100644 --- a/python/tvm/relay/op/__init__.py +++ b/python/tvm/relay/op/__init__.py @@ -1,7 +1,8 @@ #pylint: disable=wildcard-import, redefined-builtin """Relay core operators.""" # operator defs -from .op import get, register, register_schedule, register_compute, Op +from .op import get, register, register_schedule, register_compute, register_alter_op_layout, \ + Op # Operators from .reduce import * @@ -10,6 +11,7 @@ from . import nn from . import image from . import vision +from . import op_attrs # operator registry from . import _tensor diff --git a/python/tvm/relay/op/_tensor.py b/python/tvm/relay/op/_tensor.py index 75ea3da8af80a..d1035ee047e50 100644 --- a/python/tvm/relay/op/_tensor.py +++ b/python/tvm/relay/op/_tensor.py @@ -80,12 +80,3 @@ def clip_compute(attrs, inputs, output_type, target): return [topi.clip(inputs[0], attrs.a_min, attrs.a_max)] register_schedule("clip", schedule_elemwise) -register_pattern("clip", OpPattern.ELEMWISE) - -# concatenate -@register_compute("concatenate") -def concatenate_compute(attrs, inputs, output_type, target): - return [topi.concatenate(inputs, axis=attrs.axis)] - -register_schedule("concatenate", schedule_injective) -register_pattern("concatenate", OpPattern.INJECTIVE) diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index 3093032f9e404..1aaf376a7dc81 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -1,8 +1,10 @@ """Backend compiler related feature registration""" -# pylint: disable=invalid-name +# pylint: disable=invalid-name,unused-argument from __future__ import absolute_import +import topi from . 
import op as _reg from ._reduce import _schedule_reduce +from .op import schedule_injective, OpPattern schedule_injective = _reg.schedule_injective schedule_broadcast = _reg.schedule_injective @@ -15,10 +17,22 @@ _reg.register_schedule("reshape_like", schedule_injective) _reg.register_schedule("full", schedule_injective) _reg.register_schedule("full_like", schedule_injective) -_reg.register_schedule("cast", schedule_broadcast) +_reg.register_schedule("cast", schedule_injective) _reg.register_schedule("strided_slice", schedule_injective) _reg.register_schedule("slice_like", schedule_injective) _reg.register_schedule("split", schedule_injective) _reg.register_schedule("take", schedule_injective) _reg.register_schedule("transpose", schedule_injective) _reg.register_schedule("where", schedule_broadcast) + +# layout_transform +_reg.register_schedule("layout_transform", schedule_injective) +_reg.register_pattern("layout_transform", OpPattern.INJECTIVE) + +# concatenate +@_reg.register_compute("concatenate") +def concatenate_compute(attrs, inputs, output_type, target): + return [topi.concatenate(inputs, axis=attrs.axis)] + +_reg.register_schedule("concatenate", schedule_injective) +_reg.register_pattern("concatenate", OpPattern.INJECTIVE) diff --git a/python/tvm/relay/op/op.py b/python/tvm/relay/op/op.py index c777a82462c85..dd3af9c44e42b 100644 --- a/python/tvm/relay/op/op.py +++ b/python/tvm/relay/op/op.py @@ -107,7 +107,7 @@ def register_schedule(op_name, schedule=None, level=10): op_name : str The name of the op. - schedule : function + schedule : function (attrs: Attrs, outs: List[Tensor], target: Target) -> sch: Schedule The schedule function. level : int @@ -124,7 +124,8 @@ def register_compute(op_name, compute=None, level=10): op_name : str The name of the op. - compute : function + compute : function (attrs: Attrs, inputs: List[Tensor], out_type: Type, target:Target) + -> List[Tensor] The compute function. 
level : int @@ -133,6 +134,23 @@ def register_compute(op_name, compute=None, level=10): return register(op_name, "FTVMCompute", compute, level) +def register_alter_op_layout(op_name, alter_layout=None, level=10): + """Register alter op layout function for an op + + Parameters + ---------- + op_name : str + The name of the operator + + alter_layout: function (attrs: Attrs, inputs: List[Expr]) -> new_expr: Expr + The function for changing the layout or replacing the operator + + level : int + The priority level + """ + return register(op_name, "FTVMAlterOpLayout", alter_layout, level) + + def register_pattern(op_name, pattern, level=10): """Register operator pattern for an op. diff --git a/python/tvm/relay/op/op_attrs.py b/python/tvm/relay/op/op_attrs.py new file mode 100644 index 0000000000000..682d56fb9efc5 --- /dev/null +++ b/python/tvm/relay/op/op_attrs.py @@ -0,0 +1,14 @@ +"""The attributes node used for Relay operators""" + +from ...attrs import Attrs +from ..base import register_relay_attr_node + +@register_relay_attr_node +class Conv2DAttrs(Attrs): + """Attribute of a Convolution Operator""" + pass + +@register_relay_attr_node +class GlobalPool2DAttrs(Attrs): + """Attribute of a Global 2D Pooling Operator""" + pass diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index c5fedab054d2f..17caad4bb3040 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -387,3 +387,25 @@ def slice_like(data, shape_like, axes=None): The computed result. """ return _make.slice_like(data, shape_like, axes) + + +def layout_transform(data, src_layout, dst_layout): + """Transform the layout of a tensor + + Parameters + ---------- + data : relay.Expr + The source tensor to be transformed + + src_layout: str + The source layout. (e.g NCHW) + + dst_layout: str + The destination layout. (e.g. NCHW16c) + + Returns + ------- + ret : relay.Expr + The transformed tensor. 
+ """ + return _make.layout_transform(data, src_layout, dst_layout) diff --git a/src/lang/attrs.cc b/src/lang/attrs.cc index 3b273f4939ef7..1daf1e7925534 100644 --- a/src/lang/attrs.cc +++ b/src/lang/attrs.cc @@ -3,6 +3,7 @@ * \file attrs.cc */ #include +#include #include "attr_functor.h" namespace tvm { @@ -321,4 +322,9 @@ bool DictAttrsNode::ContentEqual(const Node* other, AttrsEqual equal) const { return equal(this->dict, static_cast(other)->dict); } +TVM_REGISTER_API("_AttrsListFieldInfo") +.set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = args[0].operator Attrs()->ListFieldInfo(); +}); + } // namespace tvm diff --git a/src/relay/op/layout.h b/src/relay/op/layout.h index 97160f3cbb9eb..90c920bf3aa11 100644 --- a/src/relay/op/layout.h +++ b/src/relay/op/layout.h @@ -185,7 +185,7 @@ class Layout : public NodeRef { CHECK_GT(block_size, 0); new_layout << block_size; } - new_layout << layout_simplified[i]->value; + new_layout << static_cast(layout_simplified[i]->value); } return Layout(new_layout.str()); } @@ -241,6 +241,16 @@ class Layout : public NodeRef { return operator->()->layout_simplified.size(); } + /*! \return number of super dimensions */ + size_t ndim_super() const { + size_t ct = 0; + for (auto x : operator->()->layout_simplified) { + if (IsSuperdim(x)) + ct++; + } + return ct; + } + /*! * \brief The description of the \p i-th dimension. * If it is a sub-dimension, the size will be returned as well, @@ -327,6 +337,17 @@ class Layout : public NodeRef { return operator->()->name == rhs->name; } + /*! 
+ * \brief allow output string of layout to ostream + * \param os the output stream + * \param l the layout + * \return the ostream + */ + friend std::ostream& operator<<(std::ostream& os, const Layout& l) { + os << l.name(); + return os; + } + using ContainerType = LayoutNode; private: diff --git a/src/relay/op/nn/convolution.cc b/src/relay/op/nn/convolution.cc index cb648166f7bb6..170b6b6d13c5c 100644 --- a/src/relay/op/nn/convolution.cc +++ b/src/relay/op/nn/convolution.cc @@ -7,11 +7,13 @@ #include #include +#include "../../pass/alter_op_layout.h" #include "../layout.h" namespace tvm { namespace relay { +// relay.nn.conv2d TVM_REGISTER_NODE_TYPE(Conv2DAttrs); bool Conv2DRel(const Array& types, @@ -101,6 +103,20 @@ bool Conv2DRel(const Array& types, return true; } +template +Array > Conv2DInferCorrectLayout( + const Attrs& attrs, + const Array& new_in_layouts, + const Array& old_in_layouts, + const Array> &old_in_shapes) { + const T* params = attrs.as(); + Layout out_layout(params->out_layout); + + // We always make other operators to fit the layouts of convolution layers + // So this inference ignores all inputs + return Array >{{params->data_layout, params->weight_layout}, + {out_layout.defined() ? out_layout : params->data_layout}}; +} // Positional relay function to create conv2d operator // used by frontend FFI. @@ -156,10 +172,11 @@ with the layer input to produce a tensor of outputs. .add_argument("data", "Tensor", "The input tensor.") .add_argument("weight", "Tensor", "The weight tensor.") .set_support_level(2) -.add_type_rel("Conv2D", Conv2DRel); +.add_type_rel("Conv2D", Conv2DRel) +.set_attr("FInferCorrectLayout", Conv2DInferCorrectLayout); -// Conv2DTranspose +// relay.nn.conv2d_transpose TVM_REGISTER_NODE_TYPE(Conv2DTransposeAttrs); bool Conv2DTransposeRel(const Array& types, @@ -185,6 +202,12 @@ bool Conv2DTransposeRel(const Array& types, << "Conv only support kernel layouts that are convertible from OIHW." 
<< " But got "<< kernel_layout; + Layout out_layout(param->out_layout); + if (!out_layout.defined()) out_layout = in_layout; + CHECK(out_layout.Convertible(kNCHW)) + << "Conv only support output layouts that are convertible from NCHW." + << " But got " << out_layout; + IndexExpr channels, dilated_ksize_y, dilated_ksize_x; auto dshape_nchw = ConvertLayout(data->shape, in_layout, kNCHW); @@ -241,7 +264,7 @@ bool Conv2DTransposeRel(const Array& types, if (out_dtype.bits() == 0) { out_dtype = data->dtype; } - oshape = ConvertLayout(oshape, kNCHW, in_layout); + oshape = ConvertLayout(oshape, kNCHW, out_layout); reporter->Assign(types[2], TensorTypeNode::make(oshape, out_dtype)); return true; } @@ -307,6 +330,8 @@ v (batch_size, channels, out_height, out_width) if `layout` is `NCHW` .add_argument("data", "Tensor", "The input tensor.") .add_argument("weight", "Tensor", "The weight tensor.") .set_support_level(2) +.set_attr("FInferCorrectLayout", + Conv2DInferCorrectLayout) .add_type_rel("Conv2DTranspose", Conv2DTransposeRel); } // namespace relay diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc index d3b454f35ede3..7ed43d0df0198 100644 --- a/src/relay/op/nn/nn.cc +++ b/src/relay/op/nn/nn.cc @@ -12,12 +12,14 @@ #include #include #include "../type_relations.h" +#include "../../pass/alter_op_layout.h" #include "../op_common.h" #include "../layout.h" namespace tvm { namespace relay { +// relay.nn.bias_add TVM_REGISTER_NODE_TYPE(BiasAddAttrs); bool BiasAddRel(const Array& types, @@ -74,6 +76,7 @@ RELAY_REGISTER_OP("nn.bias_add") .add_type_rel("BiasAdd", BiasAddRel); +// relay.nn.dense TVM_REGISTER_NODE_TYPE(DenseAttrs); @@ -143,6 +146,8 @@ RELAY_REGISTER_OP("nn.dense") .set_support_level(1) .add_type_rel("Dense", DenseRel); +// relay.leaky_relu +TVM_REGISTER_NODE_TYPE(LeakyReluAttrs); // Positional relay function to create leaky relu operator used by frontend FFI. 
Expr MakeLeakyRelu(Expr data, @@ -171,6 +176,7 @@ RELAY_REGISTER_OP("nn.leaky_relu") .add_argument("data", "Tensor", "Input data.") .set_support_level(3) .add_type_rel("Identity", IdentityRel) +.set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) .set_attr( "FTVMCompute", [](const Attrs& attrs, const Array& inputs, @@ -181,6 +187,7 @@ RELAY_REGISTER_OP("nn.leaky_relu") }); +// relay.prelu TVM_REGISTER_NODE_TYPE(PReluAttrs); bool PReluRel(const Array& types, @@ -235,6 +242,7 @@ where :math:`*` is an channelwise multiplication for each sample in the batch. .add_argument("alpha", "Tensor", "Input channelwise alpha.") .set_support_level(3) .add_type_rel("PRelu", PReluRel) +.set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) .set_attr( "FTVMCompute", [](const Attrs& attrs, const Array& inputs, @@ -245,6 +253,9 @@ where :math:`*` is an channelwise multiplication for each sample in the batch. }); +// relay.softmax +TVM_REGISTER_NODE_TYPE(SoftmaxAttrs); + TVM_REGISTER_API("relay.op.nn._make.softmax") .set_body([](const TVMArgs& args, TVMRetValue* rv) { auto make_func = [](Expr data, int axis) { @@ -282,6 +293,7 @@ RELAY_REGISTER_OP("nn.softmax") }); +// relay.nn.log_softmax TVM_REGISTER_API("relay.op.nn._make.log_softmax") .set_body([](const TVMArgs& args, TVMRetValue* rv) { auto make_func = [](Expr data, int axis) { @@ -321,8 +333,7 @@ RELAY_REGISTER_OP("nn.log_softmax") }); - -// BatchFlatten +// relay.nn.batch_flatten bool BatchFlattenRel(const Array& types, int num_inputs, const Attrs& attrs, @@ -410,6 +421,7 @@ RELAY_REGISTER_OP("nn.relu") .add_argument("data", "Tensor", "The input tensor.") .set_support_level(1) .add_type_rel("Identity", IdentityRel) +.set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) .set_attr("FTVMCompute", [](const Attrs& attrs, const Array& inputs, const Type& out_type, @@ -460,6 +472,7 @@ centered at that value (zero padding is added where necessary). 
.set_num_inputs(1) .add_argument("data", "Tensor", "The input tensor.") .set_support_level(2) +.set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) .add_type_rel("Identity", IdentityRel); @@ -495,6 +508,7 @@ Normalizes along dimension axis using an L2 norm .set_num_inputs(1) .add_argument("data", "Tensor", "The input tensor.") .set_support_level(2) +.set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) .add_type_rel("Identity", IdentityRel); // Dropout @@ -538,6 +552,7 @@ The whole array is rescaled by ``1/(1-p)`` to keep the expected sum of the input .set_num_inputs(1) .add_argument("data", "Tensor", "Input to which dropout will be applied.") .set_support_level(1) +.set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) .add_type_rel("Dropout", DropoutRel); // batch_norm diff --git a/src/relay/op/nn/pad.cc b/src/relay/op/nn/pad.cc index 6e02d74e6ea83..5403d0620e500 100644 --- a/src/relay/op/nn/pad.cc +++ b/src/relay/op/nn/pad.cc @@ -1,87 +1,88 @@ -/*! - * Copyright (c) 2018 by Contributors - * \file pad.cc - * \brief Implementation of operator pad - */ -#include -#include -#include -#include -#include "../layout.h" - -namespace tvm { -namespace relay { - -TVM_REGISTER_NODE_TYPE(PadAttrs); - -bool PadRel(const Array& types, - int num_inputs, - const Attrs& attrs, - const TypeReporter& reporter) { - CHECK_EQ(types.size(), 2); - const auto* data = types[0].as(); - if (data == nullptr) return false; - - const PadAttrs* param = attrs.as(); - CHECK(param != nullptr); - - // check that pad widths match lengths - CHECK(data->shape.size() == param->pad_width.size()) - << "There should be as many pad width pairs as shape dimensions " - << "but the shape has " << data->shape.size() << " dimensions " - << "and there are " << param->pad_width.size() << " pad width pairs."; - - // each pad width element should be a pair of positive integers - std::vector oshape; - for (size_t i = 0; i < param->pad_width.size(); i++) { - CHECK(param->pad_width[i].size() == 2) - << 
"Each pad width element should be a pair but at index " << i - << " there are " << param->pad_width[i].size() << " elements."; - - auto width1 = as_const_int(param->pad_width[i][0]); - auto width2 = as_const_int(param->pad_width[i][1]); - CHECK(width1 != nullptr); - CHECK(width2 != nullptr); - - CHECK(*width1 >= 0) - << "Param width elements should be positive but first pad width at " - << "index " << i << " is " << *width1 << "."; - CHECK(*width2 >= 0) - << "Param width elements should be positive but first pad width at " - << "index " << i << " is " << *width2 << "."; - - auto padding = make_const(data->shape[i].type(), *width1 + *width2); - oshape.push_back(data->shape[i] + padding); - } - - reporter->Assign(types[1], TensorTypeNode::make(Array(oshape), - data->dtype)); - return true; -} - -// Handler to create a call to the padding op used by front-end FFI -Expr MakePad(Expr data, Array > pad_width, double pad_value) { - auto attrs = make_node(); - attrs->pad_value = pad_value; - attrs->pad_width = std::move(pad_width); - static const Op& op = Op::Get("nn.pad"); - return CallNode::make(op, {data}, Attrs(attrs), {}); -} - -TVM_REGISTER_API("relay.op.nn._make.pad") -.set_body([](const TVMArgs& args, TVMRetValue* rv) { - runtime::detail::unpack_call(MakePad, args, rv); - }); - -RELAY_REGISTER_OP("nn.pad") -.describe(R"code(Pad for n-D tensor. - -)code" TVM_ADD_FILELINE) -.set_attrs_type_key("relay.attrs.PadAttrs") -.set_num_inputs(1) -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(2) -.add_type_rel("Pad", PadRel); - -} // namespace relay -} // namespace tvm +/*! 
+ * Copyright (c) 2018 by Contributors + * \file pad.cc + * \brief Implementation of operator pad + */ +#include +#include +#include +#include +#include "../layout.h" + +namespace tvm { +namespace relay { + +// relay.nn.pad +TVM_REGISTER_NODE_TYPE(PadAttrs); + +bool PadRel(const Array& types, + int num_inputs, + const Attrs& attrs, + const TypeReporter& reporter) { + CHECK_EQ(types.size(), 2); + const auto* data = types[0].as(); + if (data == nullptr) return false; + + const PadAttrs* param = attrs.as(); + CHECK(param != nullptr); + + // check that pad widths match lengths + CHECK(data->shape.size() == param->pad_width.size()) + << "There should be as many pad width pairs as shape dimensions " + << "but the shape has " << data->shape.size() << " dimensions " + << "and there are " << param->pad_width.size() << " pad width pairs."; + + // each pad width element should be a pair of positive integers + std::vector oshape; + for (size_t i = 0; i < param->pad_width.size(); i++) { + CHECK(param->pad_width[i].size() == 2) + << "Each pad width element should be a pair but at index " << i + << " there are " << param->pad_width[i].size() << " elements."; + + auto width1 = as_const_int(param->pad_width[i][0]); + auto width2 = as_const_int(param->pad_width[i][1]); + CHECK(width1 != nullptr); + CHECK(width2 != nullptr); + + CHECK(*width1 >= 0) + << "Param width elements should be positive but first pad width at " + << "index " << i << " is " << *width1 << "."; + CHECK(*width2 >= 0) + << "Param width elements should be positive but first pad width at " + << "index " << i << " is " << *width2 << "."; + + auto padding = make_const(data->shape[i].type(), *width1 + *width2); + oshape.push_back(data->shape[i] + padding); + } + + reporter->Assign(types[1], TensorTypeNode::make(Array(oshape), + data->dtype)); + return true; +} + +// Handler to create a call to the padding op used by front-end FFI +Expr MakePad(Expr data, Array > pad_width, double pad_value) { + auto attrs = make_node(); 
+ attrs->pad_value = pad_value; + attrs->pad_width = std::move(pad_width); + static const Op& op = Op::Get("nn.pad"); + return CallNode::make(op, {data}, Attrs(attrs), {}); +} + +TVM_REGISTER_API("relay.op.nn._make.pad") +.set_body([](const TVMArgs& args, TVMRetValue* rv) { + runtime::detail::unpack_call(MakePad, args, rv); + }); + +RELAY_REGISTER_OP("nn.pad") +.describe(R"code(Pad for n-D tensor. + +)code" TVM_ADD_FILELINE) +.set_attrs_type_key("relay.attrs.PadAttrs") +.set_num_inputs(1) +.add_argument("data", "Tensor", "The input tensor.") +.set_support_level(2) +.add_type_rel("Pad", PadRel); + +} // namespace relay +} // namespace tvm diff --git a/src/relay/op/nn/pooling.cc b/src/relay/op/nn/pooling.cc index 0af0bbf636336..6233e6d51776b 100644 --- a/src/relay/op/nn/pooling.cc +++ b/src/relay/op/nn/pooling.cc @@ -9,13 +9,39 @@ #include #include #include "../layout.h" +#include "../../pass/alter_op_layout.h" namespace tvm { namespace relay { +// relay.nn.max_pool2d & relay.nn.avg_pool2d TVM_REGISTER_NODE_TYPE(MaxPool2DAttrs); TVM_REGISTER_NODE_TYPE(AvgPool2DAttrs); +template +Array > Pool2DInferCorrectLayout( + const Attrs& attrs, + const Array& new_in_layouts, + const Array& old_in_layouts, + const Array> &old_in_shapes) { + // NOTE: Discard "const" qualifier here. 
+ T *params = const_cast(attrs.as()); + + if (new_in_layouts.defined()) { + CHECK_EQ(new_in_layouts.size(), 1); + + Layout raw_layout(params->layout); + Layout input = new_in_layouts[0]; + if (input.Indexof('W') == raw_layout.Indexof('W') && + input.Indexof('H') == raw_layout.Indexof('H') && + !input.Contains('w') && !input.Contains('h')) { + params->layout = input.name(); // modify self to follow the input layout + } + } + + return Array >{{params->layout}, {params->layout}}; +} + template bool Pool2DRel(const Array& types, int num_inputs, @@ -163,6 +189,7 @@ RELAY_REGISTER_OP("nn.max_pool2d") .add_argument("data", "Tensor", "The input tensor.") .set_support_level(2) .add_type_rel("MaxPool2D", Pool2DRel) +.set_attr("FInferCorrectLayout", Pool2DInferCorrectLayout) .set_attr("FTVMCompute", Pool2DCompute); @@ -219,9 +246,10 @@ Average pooling operation for one dimensional data. .add_argument("data", "Tensor", "The input tensor.") .set_support_level(2) .add_type_rel("AvgPool2D", Pool2DRel) +.set_attr("FInferCorrectLayout", Pool2DInferCorrectLayout) .set_attr("FTVMCompute", Pool2DCompute); -// Global Pool +// relay.nn.global_pool_2d & relay.nn.max_pool_2d TVM_REGISTER_NODE_TYPE(GlobalPool2DAttrs); bool GlobalPool2DRel(const Array& types, @@ -247,8 +275,9 @@ bool GlobalPool2DRel(const Array& types, const auto hidx = layout.Indexof('H'); const auto widx = layout.Indexof('W'); - std::vector oshape({dshape[0], dshape[1], dshape[2], dshape[3]}); - oshape[hidx] = oshape[widx] = 1; + Array oshape(dshape); + oshape.Set(hidx, 1); + oshape.Set(widx, 1); // assign output type reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype)); @@ -307,6 +336,8 @@ RELAY_REGISTER_OP("nn.global_avg_pool2d") .add_argument("data", "Tensor", "The input tensor.") .set_support_level(2) .add_type_rel("GlobalAvgPool2D", GlobalPool2DRel) +.set_attr("FInferCorrectLayout", + Pool2DInferCorrectLayout) .set_attr("FTVMCompute", GlobalPool2DCompute); // GlobalMaxPool @@ -338,6 +369,8 @@ 
RELAY_REGISTER_OP("nn.global_max_pool2d") .add_argument("data", "Tensor", "The input tensor.") .set_support_level(2) .add_type_rel("GlobalMaxPool2D", GlobalPool2DRel) +.set_attr("FInferCorrectLayout", + Pool2DInferCorrectLayout) .set_attr("FTVMCompute", GlobalPool2DCompute); } // namespace relay diff --git a/src/relay/op/op_common.h b/src/relay/op/op_common.h index 5bb2f24cae812..36cd04931903a 100644 --- a/src/relay/op/op_common.h +++ b/src/relay/op/op_common.h @@ -11,6 +11,7 @@ #include #include #include +#include "../pass/alter_op_layout.h" namespace tvm { namespace relay { @@ -32,21 +33,24 @@ inline std::vector AsVector(const Array &array) { * We make the decision to always only expose positional argument. * We will do rewrapping in the frontend to support language * sugars such as keyword arguments and default value. - * - * \param Prefix the prefix of the registry, for example, "relay.op._make.". - * + * \param OpName the name of registry. */ -#define RELAY_REGISTER_UNARY_OP(Prefix, OpName) \ - TVM_REGISTER_API(Prefix OpName) \ - .set_body_typed([](Expr data) { \ - static const Op& op = Op::Get(OpName); \ - return CallNode::make(op, {data}, Attrs(), {}); \ - }); \ - RELAY_REGISTER_OP(OpName) \ - .set_num_inputs(1) \ - .add_argument("data", "Tensor", "The input tensor.") \ - .set_attr("TOpPattern", kElemWise) +#define RELAY_REGISTER_UNARY_OP(OpName) \ + TVM_REGISTER_API("relay.op._make." OpName) \ + .set_body_typed([](Expr data) { \ + static const Op& op = Op::Get(OpName); \ + return CallNode::make(op, {data}, Attrs(), {}); \ + }); \ + RELAY_REGISTER_OP(OpName) \ + .set_num_inputs(1) \ + .add_argument("data", "Tensor", "The input tensor.") \ + .add_type_rel("Identity", IdentityRel) \ + .set_attr("TOpPattern", kElemWise) \ + .set_attr("TOpIsStateful", false) \ + .set_attr("FInferCorrectLayout", \ + ElemwiseArbitraryLayout) \ + /*! Quick helper macro * - Expose a positional make function to construct the node. 
@@ -56,12 +60,10 @@ inline std::vector AsVector(const Array &array) { * We will do rewrapping in the frontend to support language * sugars such as keyword arguments and default value. * - * \param Prefix the prefix of the registry, for example, "relay.op._make.". - * * \param OpName the name of registry. */ -#define RELAY_REGISTER_BINARY_OP(Prefix, OpName) \ - TVM_REGISTER_API(Prefix OpName) \ +#define RELAY_REGISTER_BINARY_OP(OpName) \ + TVM_REGISTER_API("relay.op._make." OpName) \ .set_body_typed([](Expr lhs, Expr rhs) { \ static const Op& op = Op::Get(OpName); \ return CallNode::make(op, {lhs, rhs}, Attrs(), {}); \ @@ -72,7 +74,26 @@ inline std::vector AsVector(const Array &array) { .add_argument("rhs", "Tensor", "The right hand side tensor.") \ .add_type_rel("Broadcast", BroadcastRel) \ .set_attr("TOpPattern", kBroadcast) \ - .set_attr("TOpIsStateful", false) + .set_attr("TOpIsStateful", false) \ + .set_attr("FInferCorrectLayout", \ + BinaryBroadcastLayout) + +// Comparisons +#define RELAY_REGISTER_CMP_OP(OpName) \ + TVM_REGISTER_API("relay.op._make." 
OpName) \ + .set_body_typed([](Expr lhs, Expr rhs) { \ + static const Op& op = Op::Get(OpName); \ + return CallNode::make(op, {lhs, rhs}, Attrs(), {}); \ + }); \ + RELAY_REGISTER_OP(OpName) \ + .set_num_inputs(2) \ + .add_argument("lhs", "Tensor", "The left hand side tensor.") \ + .add_argument("rhs", "Tensor", "The right hand side tensor.") \ + .add_type_rel("BroadcastComp", BroadcastCompRel) \ + .set_attr("TOpPattern", kBroadcast) \ + .set_attr("TOpIsStateful", false) \ + .set_attr("FInferCorrectLayout", \ + BinaryBroadcastLayout) } // namespace relay } // namespace tvm diff --git a/src/relay/op/tensor/binary.cc b/src/relay/op/tensor/binary.cc index 3f28bd52cd4ba..da9b1af875783 100644 --- a/src/relay/op/tensor/binary.cc +++ b/src/relay/op/tensor/binary.cc @@ -23,71 +23,65 @@ namespace relay { // Addition -RELAY_REGISTER_BINARY_OP("relay.op._make.", "add") +RELAY_REGISTER_BINARY_OP("add") .describe("Elementwise add with with broadcasting") .set_support_level(1) .set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::add)); // Subtraction -RELAY_REGISTER_BINARY_OP("relay.op._make.", "subtract") +RELAY_REGISTER_BINARY_OP("subtract") .describe("Elementwise substract with broadcasting") .set_support_level(1) .set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::subtract)); // Right shift -RELAY_REGISTER_BINARY_OP("relay.op._make.", "right_shift") +RELAY_REGISTER_BINARY_OP("right_shift") .describe("Elementwise right shift with broadcasting") .set_support_level(4) .set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::right_shift)); -RELAY_REGISTER_BINARY_OP("relay.op._make.", "left_shift") + +RELAY_REGISTER_BINARY_OP("left_shift") .describe("Elementwise left shift with broadcasting") .set_support_level(4) .set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::left_shift)); -RELAY_REGISTER_BINARY_OP("relay.op._make.", "maximum") + +RELAY_REGISTER_BINARY_OP("maximum") .describe("Elementwise maximum of two tensors with broadcasting") .set_support_level(4) .set_attr("FTVMCompute", 
RELAY_BINARY_COMPUTE(topi::maximum)); -RELAY_REGISTER_BINARY_OP("relay.op._make.", "minimum") + +RELAY_REGISTER_BINARY_OP("minimum") .describe("Elementwise minimum of two tensors with broadcasting") .set_support_level(4) .set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::minimum)); -RELAY_REGISTER_BINARY_OP("relay.op._make.", "divide") + +RELAY_REGISTER_BINARY_OP("divide") .describe("Elementwise divide with broadcasting") .set_support_level(1) .set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::divide)); -RELAY_REGISTER_BINARY_OP("relay.op._make.", "multiply") + +RELAY_REGISTER_BINARY_OP("multiply") .describe("Elementwise multiply with broadcasting") .set_support_level(1) .set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::multiply)); -RELAY_REGISTER_BINARY_OP("relay.op._make.", "power") + +RELAY_REGISTER_BINARY_OP("power") .describe("Elementwise power with broadcasting") .set_support_level(4) .set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::power)); -RELAY_REGISTER_BINARY_OP("relay.op._make.", "mod") + +RELAY_REGISTER_BINARY_OP("mod") .describe("Elementwise mod with broadcasting") .set_support_level(1) .set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::mod)); -// Comparisons -#define RELAY_REGISTER_CMP_OP(OpName) \ - TVM_REGISTER_API("relay.op._make." 
OpName) \ - .set_body_typed([](Expr lhs, Expr rhs) { \ - static const Op& op = Op::Get(OpName); \ - return CallNode::make(op, {lhs, rhs}, Attrs(), {}); \ - }); \ - RELAY_REGISTER_OP(OpName) \ - .set_num_inputs(2) \ - .add_argument("lhs", "Tensor", "The left hand side tensor.") \ - .add_argument("rhs", "Tensor", "The right hand side tensor.") \ - .add_type_rel("BroadcastComp", BroadcastCompRel) \ - .set_attr("TOpPattern", kBroadcast) RELAY_REGISTER_CMP_OP("equal") .describe("Elementwise equal compare with broadcasting") diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 4a052881d7bf3..fcf7f6fe32991 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -11,9 +11,12 @@ #include #include #include +#include #include #include "../op_common.h" #include "../../../arithmetic/compute_expr.h" +#include "../../pass/alter_op_layout.h" +#include "../layout.h" namespace tvm { namespace relay { @@ -156,6 +159,7 @@ RELAY_REGISTER_OP("expand_dims") .set_attr("FTVMCompute", ExpandDimsCompute) .set_attr("TOpPattern", kBroadcast); +// relay.concatenate TVM_REGISTER_NODE_TYPE(ConcatenateAttrs); bool ConcatenateRel(const Array& types, @@ -201,6 +205,42 @@ bool ConcatenateRel(const Array& types, return true; } +Array> ConcatenateLayout( + const Attrs& attrs, + const Array& new_in_layouts, + const Array& old_in_layouts, + const Array> &old_in_shapes) { + const ConcatenateAttrs* param = attrs.as(); + + size_t axis = param->axis < 0 ? param->axis + old_in_shapes[0].size() : + static_cast(param->axis); + + Layout ret; + if (new_in_layouts.defined()) { // this function is called after some operators are alternated. 
+ Layout::LayoutDim concate_dim = old_in_layouts[0][axis]; + for (size_t i = 0; i < new_in_layouts.size(); ++i) { + if (new_in_layouts[i].ndim() > axis && + new_in_layouts[i][axis] == concate_dim) { + ret = new_in_layouts[i]; + break; + } + } + } else { // this function is called on the original correct relay ir + for (size_t i = 0; i < old_in_layouts.size(); ++i) { + if (old_in_layouts[i].defined()) { + ret = old_in_layouts[i]; + break; + } + } + + if (ret.ndim() <= axis || Layout::IsSubdim(ret[axis])) { + return Array > {{Layout::Undef()}, {Layout::Undef()}}; + } + } + + return Array > {Array(old_in_layouts.size(), ret), {ret}}; +} + Expr MakeConcatenate(Expr data, int axis) { auto attrs = make_node(); @@ -226,7 +266,8 @@ RELAY_REGISTER_OP("concatenate") .set_num_inputs(1) .add_argument("data", "Tensor", "The input list of tensors.") .set_support_level(1) -.add_type_rel("Concatenate", ConcatenateRel); +.add_type_rel("Concatenate", ConcatenateRel) +.set_attr("FInferCorrectLayout", ConcatenateLayout); /* relay.transpose */ TVM_REGISTER_NODE_TYPE(TransposeAttrs); @@ -323,7 +364,6 @@ RELAY_REGISTER_OP("transpose") .set_attr("TOpPattern", kInjective); /* relay.reshape */ - TVM_REGISTER_NODE_TYPE(ReshapeAttrs); bool ReshapeRel(const Array& types, @@ -1252,7 +1292,7 @@ Examples:: .set_attr("TOpPattern", kInjective); -// Split +// relay.split TVM_REGISTER_NODE_TYPE(SplitAttrs); bool SplitRel(const Array& types, @@ -1367,6 +1407,7 @@ the entries indicate where along axis the array is split. .set_attr("TOpPattern", kInjective); +// relay.slice_like TVM_REGISTER_NODE_TYPE(SliceLikeAttrs); /*! 
@@ -1513,5 +1554,104 @@ RELAY_REGISTER_OP("slice_like") .set_attr("FTVMCompute", SliceLikeCompute) .set_attr("TOpPattern", kInjective); + +// relay.layout_transform +Array LayoutTransformCompute(const Attrs& attrs, + const Array& inputs, + const Type& out_type, + const Target& target) { + const LayoutTransformAttrs *param = attrs.as(); + CHECK(param != nullptr); + + Layout src_layout(param->src_layout); + Layout dst_layout(param->dst_layout); + + if (src_layout.Equals(dst_layout)) { + return Array{ inputs[0] }; + } + + CHECK(src_layout.defined() && dst_layout.defined()) + << "cannot convert from/to undefined layout"; + CHECK(src_layout.Convertible(dst_layout)) + << "cannot convert from " << param->src_layout << " to " << param->dst_layout; + + const auto& out_shape = ConvertLayout(inputs[0]->shape, src_layout, dst_layout); + return Array { + topi::layout_transform(inputs[0], out_shape, [&](const Array& dst_indices) { + std::vector dst_to_src_indices; + for (size_t i = 0; i < src_layout.ndim(); ++i) { + Layout::LayoutDim src_axis = src_layout[i]; + int dst_major_pos = dst_layout.Indexof(Layout::ToSuperdim(src_axis)); + int dst_minor_pos = dst_layout.Indexof(Layout::ToSubdim(src_axis)); + int32_t src_factor = static_cast(src_layout.Subsizeof(src_axis)); + int32_t dst_factor = static_cast(dst_layout.Subsizeof(src_axis)); + + tvm::Expr src_index(dst_indices[dst_major_pos]); + if (dst_minor_pos >= 0) { + CHECK_GT(dst_factor, 0); + src_index = src_index * dst_factor + dst_indices[dst_minor_pos]; + } + if (Layout::IsSuperdim(src_axis) && src_factor > 0) { + src_index = src_index / src_factor; + } else if (Layout::IsSubdim(src_axis) && src_factor > 0) { + src_index = src_index % src_factor; + } + dst_to_src_indices.push_back(src_index); + } + return Array(dst_to_src_indices); + }) + }; +} + +bool LayoutTransformRel(const Array& types, + int num_inputs, + const Attrs& attrs, + const TypeReporter& reporter) { + const auto* data = types[0].as(); + CHECK(data != nullptr); + 
const LayoutTransformAttrs* params = attrs.as(); + + Layout src_layout(params->src_layout); + Layout dst_layout(params->dst_layout); + + CHECK(src_layout.defined() && dst_layout.defined()) + << "cannot convert from/to undefined layout"; + CHECK(src_layout.Convertible(dst_layout)) + << "cannot convert from " << params->src_layout << " to " << params->dst_layout; + + const auto& out_shape = ConvertLayout(data->shape, src_layout, dst_layout); + reporter->Assign(types[1], TensorTypeNode::make(out_shape, data->dtype)); + return true; +} + +Expr MakeLayoutTransform(Expr data, + std::string src_layout, + std::string dst_layout) { + auto attrs = make_node(); + attrs->src_layout = std::move(src_layout); + attrs->dst_layout = std::move(dst_layout); + static const Op& op = Op::Get("layout_transform"); + return CallNode::make(op, {data}, Attrs(attrs), {}); +} + +TVM_REGISTER_API("relay.op._make.layout_transform") +.set_body([](const TVMArgs& args, TVMRetValue* rv) { + runtime::detail::unpack_call(MakeLayoutTransform, args, rv); +}); + +RELAY_REGISTER_OP("layout_transform") +.describe(R"code(Transform the input data layout. + +For transforming from NCHW to N16cHWC, the `__layout_transform__` operator reshapes +the input array by output[n, c, h, w, C] = data[n, C*16+c, h, w] + +)code" TVM_ADD_FILELINE) +.set_attrs_type_key("relay.attrs.LayoutTransformAttrs") +.set_num_inputs(1) +.add_argument("data", "Tensor", "The input tensor.") +.add_type_rel("layout_transform", LayoutTransformRel) +.set_support_level(5) +.set_attr("FTVMCompute", LayoutTransformCompute); + } // namespace relay } // namespace tvm diff --git a/src/relay/op/tensor/unary.cc b/src/relay/op/tensor/unary.cc index fef0302a05074..b83fdacda1ee5 100644 --- a/src/relay/op/tensor/unary.cc +++ b/src/relay/op/tensor/unary.cc @@ -22,7 +22,7 @@ namespace relay { } \ -RELAY_REGISTER_UNARY_OP("relay.op._make.", "log") +RELAY_REGISTER_UNARY_OP("log") .describe(R"code(Returns the log input array, computed element-wise. .. 
math:: @@ -30,11 +30,10 @@ RELAY_REGISTER_UNARY_OP("relay.op._make.", "log") )code" TVM_ADD_FILELINE) .set_support_level(1) -.add_type_rel("Identity", IdentityRel) .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::log)); -RELAY_REGISTER_UNARY_OP("relay.op._make.", "exp") +RELAY_REGISTER_UNARY_OP("exp") .describe(R"code(Returns the exp input array, computed element-wise. .. math:: @@ -42,36 +41,30 @@ RELAY_REGISTER_UNARY_OP("relay.op._make.", "exp") )code" TVM_ADD_FILELINE) .set_support_level(1) -.add_type_rel("Identity", IdentityRel) .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::exp)); - -RELAY_REGISTER_UNARY_OP("relay.op._make.", "sqrt") -.describe(R"code(Returns the sqrt input array, computed element-wise. +RELAY_REGISTER_UNARY_OP("sqrt") +.describe(R"code(Returns the sqrt input array, computed element-wise. .. math:: sqrt(x) )code" TVM_ADD_FILELINE) .set_support_level(1) -.add_type_rel("Identity", IdentityRel) .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::sqrt)); -RELAY_REGISTER_UNARY_OP("relay.op._make.", "zeros_like") +RELAY_REGISTER_UNARY_OP("zeros_like") .describe(R"code(Returns an array of zeros, with same type and shape as the input. )code" TVM_ADD_FILELINE) -.set_support_level(1) -.add_type_rel("Identity", IdentityRel); +.set_support_level(4); - -RELAY_REGISTER_UNARY_OP("relay.op._make.", "ones_like") +RELAY_REGISTER_UNARY_OP("ones_like") .describe(R"code(Returns an array of ones, with same type and shape as the input. )code" TVM_ADD_FILELINE) -.set_support_level(1) -.add_type_rel("Identity", IdentityRel); +.set_support_level(4); -RELAY_REGISTER_UNARY_OP("relay.op._make.", "sigmoid") +RELAY_REGISTER_UNARY_OP("sigmoid") .describe(R"code(Returns the sigmoid input array, computed element-wise. ..
math:: @@ -79,48 +72,47 @@ RELAY_REGISTER_UNARY_OP("relay.op._make.", "sigmoid") )code" TVM_ADD_FILELINE) .set_support_level(1) -.add_type_rel("Identity", IdentityRel) .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::sigmoid)); -RELAY_REGISTER_UNARY_OP("relay.op._make.", "copy") +RELAY_REGISTER_UNARY_OP("copy") .describe(R"code(Copy a tensor. )code" TVM_ADD_FILELINE) .set_support_level(3) -.add_type_rel("Identity", IdentityRel) .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::identity)); // relay.clip TVM_REGISTER_NODE_TYPE(ClipAttrs); TVM_REGISTER_API("relay.op._make.clip") - .set_body_typed([](Expr a, double a_min, double a_max) { - auto attrs = make_node(); - attrs->a_min = a_min; - attrs->a_max = a_max; - static const Op& op = Op::Get("clip"); - return CallNode::make(op, {a}, Attrs(attrs), {}); - }); +.set_body_typed([](Expr a, double a_min, double a_max) { + auto attrs = make_node(); + attrs->a_min = a_min; + attrs->a_max = a_max; + static const Op& op = Op::Get("clip"); + return CallNode::make(op, {a}, Attrs(attrs), {}); +}); RELAY_REGISTER_OP("clip") - .describe(R"code(Clip tensor values. - This function takes a tensor, a minimum value `a_min`, and a maximum value `a_max`, and returns a clipped tensor where all values below `a_min` are set to `a_min` and all values above `a_max` are set to `a_max`. `a_min` and `a_max` are cast to the tensor's dtype. - )code" TVM_ADD_FILELINE) - .set_num_inputs(1) - .add_argument("tensor", "Tensor", "The input tensor.") - .set_support_level(3) - .add_type_rel("Clip", IdentityRel); - +.describe(R"code(Clip tensor values. +This function takes a tensor, a minimum value `a_min`, and a maximum value `a_max`, and returns a clipped tensor where all values below `a_min` are set to `a_min` and all values above `a_max` are set to `a_max`. `a_min` and `a_max` are cast to the tensor's dtype. 
+)code" TVM_ADD_FILELINE) +.set_num_inputs(1) +.add_argument("data", "Tensor", "The input tensor.") +.add_type_rel("Identity", IdentityRel) +.set_attr("TOpPattern", kElemWise) +.set_attr("TOpIsStateful", false) +.set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) +.set_support_level(3); -RELAY_REGISTER_UNARY_OP("relay.op._make.", "floor") +RELAY_REGISTER_UNARY_OP("floor") .describe(R"code(Returns the floor of input array, computed element-wise. )code" TVM_ADD_FILELINE) .set_support_level(3) -.add_type_rel("Identity", IdentityRel) .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::floor)); -RELAY_REGISTER_UNARY_OP("relay.op._make.", "ceil") +RELAY_REGISTER_UNARY_OP("ceil") .describe(R"code(Returns the ceil of input array, computed element-wise. .. math:: @@ -128,11 +120,10 @@ RELAY_REGISTER_UNARY_OP("relay.op._make.", "ceil") )code" TVM_ADD_FILELINE) .set_support_level(3) -.add_type_rel("Identity", IdentityRel) .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::ceil)); -RELAY_REGISTER_UNARY_OP("relay.op._make.", "trunc") +RELAY_REGISTER_UNARY_OP("trunc") .describe(R"code(Returns the trunc of input array, computed element-wise. .. math:: @@ -140,11 +131,9 @@ RELAY_REGISTER_UNARY_OP("relay.op._make.", "trunc") )code" TVM_ADD_FILELINE) .set_support_level(3) -.add_type_rel("Identity", IdentityRel) .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::trunc)); - -RELAY_REGISTER_UNARY_OP("relay.op._make.", "round") +RELAY_REGISTER_UNARY_OP("round") .describe(R"code(Returns the round of input array, computed element-wise. .. math:: @@ -152,11 +141,10 @@ RELAY_REGISTER_UNARY_OP("relay.op._make.", "round") )code" TVM_ADD_FILELINE) .set_support_level(3) -.add_type_rel("Identity", IdentityRel) .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::round)); -RELAY_REGISTER_UNARY_OP("relay.op._make.", "abs") +RELAY_REGISTER_UNARY_OP("abs") .describe(R"code(Returns the abs of input array, computed element-wise. .. 
math:: @@ -164,11 +152,10 @@ RELAY_REGISTER_UNARY_OP("relay.op._make.", "abs") )code" TVM_ADD_FILELINE) .set_support_level(3) -.add_type_rel("Identity", IdentityRel) .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::abs)); -RELAY_REGISTER_UNARY_OP("relay.op._make.", "tanh") +RELAY_REGISTER_UNARY_OP("tanh") .describe(R"code(Returns the tanh of input array, computed element-wise. .. math:: @@ -176,11 +163,10 @@ RELAY_REGISTER_UNARY_OP("relay.op._make.", "tanh") )code" TVM_ADD_FILELINE) .set_support_level(1) -.add_type_rel("Identity", IdentityRel) .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::tanh)); -RELAY_REGISTER_UNARY_OP("relay.op._make.", "negative") +RELAY_REGISTER_UNARY_OP("negative") .describe(R"code(Returns the numeric negative of input array, computed element-wise. .. math:: @@ -188,7 +174,6 @@ RELAY_REGISTER_UNARY_OP("relay.op._make.", "negative") )code" TVM_ADD_FILELINE) .set_support_level(3) -.add_type_rel("Identity", IdentityRel) .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::negative)); } // namespace relay diff --git a/src/relay/pass/alter_op_layout.cc b/src/relay/pass/alter_op_layout.cc new file mode 100644 index 0000000000000..5c4475259086c --- /dev/null +++ b/src/relay/pass/alter_op_layout.cc @@ -0,0 +1,312 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file alter_op_layout.cc + * \brief Alternate the layouts of operators or replace primitive operators with + other expressions. This pass can be used for computing convolution in + custom layouts or other general weight pre-transformation. 
+ */ +#include +#include +#include +#include +#include +#include +#include +#include + +#include "alter_op_layout.h" + +namespace tvm { +namespace relay { + +namespace alter_op_layout { + +// Make a transform CallNode +Expr TransformLayout(Expr raw, Layout src_layout, Layout dst_layout) { + if (src_layout.Equals(dst_layout)) { return raw; } + CHECK(src_layout.defined() && dst_layout.defined()) + << "Cannot insert layout transform because there are undefined layouts"; + CHECK(src_layout.Convertible(dst_layout)) + << "Cannot insert layout transform because there are inconvertible layouts: " + << src_layout << " v.s. " << dst_layout; + static auto &transform_op = Op::Get("layout_transform"); + NodePtr attrs = make_node(); + attrs->src_layout = src_layout.name(); + attrs->dst_layout = dst_layout.name(); + Call transform = CallNode::make(transform_op, {raw}, Attrs{attrs}); + return transform; +} + +// Memorize layout transform so we can reuse internal transformed nodes +class TransformMemorizerNode : public Node { + public: + // map from (Expr, src_layout, dst_layout) to transformed Expr + using TransformKey = std::tuple; + struct key_hash : public std::unary_function { + std::size_t operator()(const TransformKey& k) const { + return dmlc::HashCombine(dmlc::HashCombine( + std::hash()(std::get<0>(k)), std::get<1>(k)), (std::get<2>(k))); + } + }; + + std::unordered_map memo; + static constexpr const char *_type_key = "relay.alter_op_layout.TransformMemorizerNode"; + TVM_DECLARE_NODE_TYPE_INFO(TransformMemorizerNode, Node); +}; + +class TransformMemorizer : public NodeRef { + public: + TransformMemorizer() {} + explicit TransformMemorizer(NodePtr n) : NodeRef(n) {} + + TransformMemorizerNode* operator->() { + return static_cast(node_.get()); + } + + // Transform layout with memorizer + Expr Transform(Expr raw, const Layout& src_layout, const Layout& dst_layout) { + if (src_layout.Equals(dst_layout)) { return raw; } + + std::tuple key = + std::make_tuple<>(raw.get(), 
src_layout.name(), dst_layout.name()); + auto& memo = operator->()->memo; + + auto iter = memo.find(key); + if (iter != memo.end()) { + return iter->second; + } else { + Expr transform = TransformLayout(raw, src_layout, dst_layout); + memo[key] = transform; + return transform; + } + } + + using ContainerType = TransformMemorizerNode; +}; + + +// TempExprNode during layout transform +// Instance of this expr will be Realized to normal expr ultimately +class LayoutAlternatedExprNode : public TempExprNode { + public: + Expr value; + Layout old_layout; + Layout new_layout; + TransformMemorizer memorizer; + + Expr Realize() const final { + // NOTE: use a copy to discard the "const" qualifier + TransformMemorizer tmp_memorizer = memorizer; + // fallback to old layout + return tmp_memorizer.Transform(value, new_layout, old_layout); + } + + void VisitAttrs(AttrVisitor *v) final { + v->Visit("value", &value); + v->Visit("old_layout", &old_layout); + v->Visit("new_layout", &new_layout); + } + + static constexpr const char *_type_key = "relay.alter_op_layout.LayoutAlternatedExprNode"; + TVM_DECLARE_NODE_TYPE_INFO(LayoutAlternatedExprNode, TempExprNode); +}; + +RELAY_DEFINE_NODE_REF(LayoutAlternatedExpr, LayoutAlternatedExprNode, TempExpr); + +// Call registered FInferCorrectLayout of an op. 
+// Parameters are the same as the parameters for FInferCorrectLayout +// Returns inferred_input_layout, inferred_output_layout, success +std::tuple, Array, bool> CallInfer( + const Call& call, + const Array& new_in_layouts, + const Array& old_in_layouts, + const Array > &old_in_shapes) { + static auto finfer_layout = Op::GetAttr("FInferCorrectLayout"); + + Op op = Downcast(call->op); + if (finfer_layout.count(op)) { + Array > inferred_layouts; + inferred_layouts = finfer_layout[op](call->attrs, new_in_layouts, + old_in_layouts, old_in_shapes); + CHECK_EQ(inferred_layouts.size(), 2) + << "FInferCorrectLayout should return an array with size of 2"; + for (auto x : inferred_layouts) { + for (auto y : x) { + if (!y.defined()) { // inference fails + return std::make_tuple<>(Array(nullptr), Array(nullptr), false); + } + } + } + return std::make_tuple<>(inferred_layouts[0], inferred_layouts[1], true); + } else { + return std::make_tuple<>(Array(nullptr), Array(nullptr), false); + } +} + +// Call registered FTVMAlterOpLayout of an op +// Returns the altered expression +Call CallAlter(const Call& ref_call, + const std::vector& new_args) { + static auto falter_layout = Op::GetAttr("FTVMAlterOpLayout"); + Op op = Downcast(ref_call->op); + + Expr new_e; + bool modified = false; + if (falter_layout.count(op)) { + tvm::Array tinfos; + for (auto expr : ref_call->args) { + auto ttype = expr->type_as(); + tinfos.push_back(tvm::placeholder(ttype->shape, ttype->dtype)); + } + Expr altered_value = falter_layout[op](ref_call->attrs, new_args, tinfos); + if (altered_value.defined()) { + new_e = altered_value; + modified = true; + } + } + if (!modified) { + new_e = CallNode::make(ref_call->op, new_args, + ref_call->attrs, ref_call->type_args); + } + + const CallNode *new_call = new_e.as(); + CHECK(new_call) << "Can only replace the original operator with another call node"; + return GetRef(new_call); +} + +Expr AlterOpLayoutRewrite(const Call &ref_call, + const Array &new_args, + const 
NodeRef& ctx) { + std::vector inputs; + std::vector normal_new_args; + Array > input_shapes; + + // NOTE: discard the "const" qualifier + TransformMemorizer memorizer = Downcast(ctx); + + // fill incomplete state and expand tuple + for (auto new_arg : new_args) { + auto push_back_one_arg = [&](Expr arg) { + // We always expect LayoutAlternatedExpr. + // This is used to convert the normal Expr to LayoutAlternatedExpr. + if (const LayoutAlternatedExprNode *inp = arg.as()) { + inputs.push_back(GetRef(inp)); + normal_new_args.push_back(inp->value); + } else { + auto inode = make_node(); + inode->value = arg; + inode->memorizer = memorizer; + inputs.push_back(LayoutAlternatedExpr(inode)); + normal_new_args.push_back(arg); + } + }; + + if (new_arg->is_type()) { + Tuple tuple_new_arg = Downcast(new_arg); + for (auto x : tuple_new_arg->fields) { + push_back_one_arg(x); + } + } else { + push_back_one_arg(new_arg); + } + } + + // old_in, new_in = state[inputs] + Array old_in, old_out, new_in, new_out, new_in2; + for (auto inp : inputs) { + old_in.push_back(inp->old_layout); + new_in.push_back(inp->new_layout); + } + + for (auto arg : ref_call->args) { + if (arg->is_type()) { // expand tuple + Tuple tuple_arg = Downcast(arg); + for (auto x : tuple_arg->fields) { + input_shapes.push_back(x->type_as()->shape); + } + } else { + input_shapes.push_back(arg->type_as()->shape); + } + } + + // old_in, old_out = op.infer(old_in) + bool success = false; + std::tie(old_in, old_out, success) = CallInfer(ref_call, + Array(nullptr), + old_in, input_shapes); + if (!success) { return Expr(nullptr); } + CHECK_EQ(old_in.size(), new_in.size()); + + // if new_in == 'undef': new_in = old_in + for (size_t i = 0; i < new_in.size(); ++i) { + if (!new_in[i].defined()) { + new_in.Set(i, old_in[i]); + } + } + + // new_op = alter(op) + Call new_call = CallAlter(ref_call, normal_new_args); + + // new_in2, new_out = op.infer(new_in) + if (new_call->op->is_type()) { + success = false; + std::tie(new_in2, 
new_out, success) = CallInfer(new_call, new_in, old_in, input_shapes); + if (!success) { return Expr(nullptr); } + } else { + return Expr(nullptr); + } + + CHECK_EQ(new_out.size(), old_out.size()) + << "The number of output nodes should keep the same during alter_op_layout"; + CHECK_EQ(new_in.size(), new_in2.size()) + << "The number of input nodes should keep the same during alter_op_layout"; + + // if (new_in != new_in2): insert transform (new_in -> new_in2) + Array transformed_args; + for (size_t i = 0; i < inputs.size(); ++i) { + transformed_args.push_back(memorizer.Transform(new_call->args[i], new_in[i], new_in2[i])); + } + + // state[node] = (old_out, new_out) + CHECK(ref_call->checked_type_.defined()) + << "Call infer_type pass before alter_op_layout pass"; + + if (ref_call->checked_type()->is_type()) { + Expr tuple_output = CallNode::make(new_call->op, transformed_args, + new_call->attrs, new_call->type_args); + Array fields; + for (size_t i = 0; i < new_out.size(); ++i) { + auto rnode = make_node(); + rnode->value = TupleGetItemNode::make(tuple_output, i); + rnode->old_layout = old_out[i]; + rnode->new_layout = new_out[i]; + rnode->memorizer = memorizer; + fields.push_back(Expr(rnode)); + } + return TupleNode::make(fields); + } else { + auto rnode = make_node(); + CHECK_EQ(new_out.size(), 1); + rnode->value = CallNode::make(new_call->op, transformed_args, + new_call->attrs, new_call->type_args); + rnode->old_layout = old_out[0]; + rnode->new_layout = new_out[0]; + rnode->memorizer = memorizer; + return Expr(rnode); + } +} + +TVM_REGISTER_API("relay._ir_pass.AlterOpLayout") +.set_body([](TVMArgs args, TVMRetValue *ret) { + TransformMemorizer transformMemorizer(make_node()); + auto fcontext = [&](const Call& call) -> NodeRef{ + return transformMemorizer; + }; + + *ret = ForwardRewrite(args[0], AlterOpLayoutRewrite, fcontext); +}); + +} // namespace alter_op_layout + +} // namespace relay +} // namespace tvm diff --git a/src/relay/pass/alter_op_layout.h 
b/src/relay/pass/alter_op_layout.h new file mode 100644 index 0000000000000..fcb7b379a0ec1 --- /dev/null +++ b/src/relay/pass/alter_op_layout.h @@ -0,0 +1,119 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file alter_op_layout.h + * \brief Alternate the layouts of operators or replace primitive operators with + other expressions. This pass can be used for computing convolution in + custom layouts or other general weight pre-transformation. + */ + +#ifndef TVM_RELAY_PASS_ALTER_OP_LAYOUT_H_ +#define TVM_RELAY_PASS_ALTER_OP_LAYOUT_H_ + +#include + +#include "../op/layout.h" + +namespace tvm { +namespace relay { + +/*! + * \brief Infer & correct function of node layout. See \p Layout for layout convention + * \param attrs The attribute of the node. + * \param new_in_layouts The layouts of input arguments after alter_op_layout. + * This can be undefined, which means we call this function before alternating + * any operators. + * \param old_in_layouts The layouts of input arguments before alter_op_layout. + * \param old_in_shapes The shapes of old input arguments. + * \return infered_layout An array of two elements that are inferred input layouts and + * inferred output layouts. + */ +using FInferCorrectLayout = runtime::TypedPackedFunc< + Array>(const Attrs& attrs, + const Array& new_in_layouts, + const Array& old_in_layouts, + const Array> &old_in_shapes)>; + +/*! \brief take arbitrary input layout and copy to output */ +inline Array > ElemwiseArbitraryLayout(const Attrs& attrs, + const Array& new_in_layouts, + const Array& old_in_layouts, + const Array> &old_in_shapes) { + Layout ret; + + if (new_in_layouts.defined()) { + CHECK_GE(new_in_layouts.size(), 1); + ret = new_in_layouts[0]; + } else { + for (size_t i = 0; i < old_in_layouts.size(); ++i) { + if (old_in_layouts[i].defined()) { + ret = old_in_layouts[i]; + break; + } + } + } + + return Array >{Array(old_in_layouts.size(), ret), {ret}}; +} + +/*! 
\brief Infer layout for binary broadcast operators */ +inline Array > BinaryBroadcastLayout(const Attrs& attrs, + const Array& new_in_layouts, + const Array& old_in_layouts, + const Array> &old_in_shapes) { + Array layouts; + + if (new_in_layouts.defined()) { + layouts.assign(new_in_layouts.begin(), new_in_layouts.end()); + } else { + layouts.assign(old_in_layouts.begin(), old_in_layouts.end()); + } + + if (!layouts[0].defined() && !layouts[1].defined()) { + // both undefined, infer fails + return Array > {{Layout::Undef()}, {Layout::Undef()}}; + } else if (!layouts[0].defined() || !layouts[1].defined()) { + // only one is defined, use shape information to help infer + int defined_idx = layouts[0].defined() ? 0 : 1; + int undef_idx = 1 - defined_idx; + + if (old_in_shapes[defined_idx].size() >= old_in_shapes[undef_idx].size()) { + layouts.Set(undef_idx, + layouts[defined_idx].Sublayout( + old_in_shapes[defined_idx].size() - old_in_shapes[undef_idx].size(), + old_in_shapes[undef_idx].size())); + return Array > {layouts, {layouts[defined_idx]}}; + } else { + // only know the tensor with smaller dimensions, + // so we cannot infer the final broadcasted output. + // fails in this case. + return Array > {{Layout::Undef()}, {Layout::Undef()}}; + } + } else { + // try to broadcast the tensors to the larger dimension + int large_idx = layouts[0].ndim_super() >= layouts[1].ndim_super() ? 
0 : 1; + int small_idx = 1 - large_idx; + Layout ret = layouts[large_idx]; + + // extract common part + size_t i = layouts[large_idx].ndim(); + for (; i != 0; --i) { + auto dim = layouts[large_idx][i-1]; + if (!layouts[small_idx].Contains(Layout::ToSuperdim(dim))) { + break; + } + } + + Layout common_part = layouts[large_idx].Sublayout(i, layouts[large_idx].ndim() - i); + if (!layouts[small_idx].Convertible(common_part)) { // fail + return Array > {{Layout::Undef()}, {Layout::Undef()}}; + } + + layouts.Set(small_idx, common_part); + return Array > {layouts, {ret}}; + } +} + +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_PASS_ALTER_OP_LAYOUT_H_ diff --git a/src/relay/pass/canonicalize_ops.cc b/src/relay/pass/canonicalize_ops.cc new file mode 100644 index 0000000000000..77cd59e2afd86 --- /dev/null +++ b/src/relay/pass/canonicalize_ops.cc @@ -0,0 +1,46 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file canonicalize_ops.cc + * \brief Canonicalize special operators to basic operators. + This can simplify latter analysis. (e.g. Expand bias_add to expand_dims and broadcast_add.) 
+ */ +#include +#include +#include +#include "pattern_util.h" + +namespace tvm { +namespace relay { + +class BiasAddSimplifier : public ExprMutator { + public: + Expr VisitExpr_(const CallNode* n) { + static const Op& bias_add = Op::Get("nn.bias_add"); + auto new_n = ExprMutator::VisitExpr_(n); + if (n->op.same_as(bias_add)) { + Call call = Downcast(new_n); + CHECK_EQ(call->args.size(), 2); + const BiasAddAttrs* param = call->attrs.as(); + + auto ttype = call->args[0]->type_as(); + size_t n_dim = ttype->shape.size(); + Expr expanded_bias = ExpandBiasToMatchAxis(call->args[1], n_dim, {param->axis}); + Expr ret = Add(call->args[0], expanded_bias); + ret->checked_type_ = n->checked_type_; + return ret; + } + return new_n; + } +}; + +Expr CanonicalizeOps(const Expr& e) { + return BiasAddSimplifier().Mutate(e); +} + +TVM_REGISTER_API("relay._ir_pass.canonicalize_ops") +.set_body([](TVMArgs args, TVMRetValue* ret) { +*ret = CanonicalizeOps(args[0]); +}); + +} // namespace relay +} // namespace tvm diff --git a/src/relay/pass/fold_scale_axis.cc b/src/relay/pass/fold_scale_axis.cc index bcb91e7e57377..c56ee98a3969f 100644 --- a/src/relay/pass/fold_scale_axis.cc +++ b/src/relay/pass/fold_scale_axis.cc @@ -29,11 +29,11 @@ using runtime::TypedPackedFunc; // FoldScaleAxis algorithm: // // The general idea is to transform Expr to tuple of -// (value, axes, scale), where the final result satiesfies: +// (value, axes, scale), where the final result satisfies: // // result = value // for i, k in enumerate(axes): -// k-ith dimension of result *= i-th dimension of scale +// k-th dimension of result *= i-th dimension of scale // // Then we can propagate this signal along and fold the scale if necessary. 
// However, it is possible that certain scale may never be consumed diff --git a/src/relay/pass/forward_rewrite.cc b/src/relay/pass/forward_rewrite.cc index 7873db80c6b04..4f33d4a053b75 100644 --- a/src/relay/pass/forward_rewrite.cc +++ b/src/relay/pass/forward_rewrite.cc @@ -42,13 +42,20 @@ class TempRealizer : private ExprMutator { class ForwardRewriter : private ExprMutator { public: - ForwardRewriter(const OpMap& rewrite_map, + ForwardRewriter(const OpMap* rewrite_map, std::function fcontext, std::function fmulti_ref_trigger) : rewrite_map_(rewrite_map), fcontext_(fcontext), - fmulti_ref_trigger_(fmulti_ref_trigger) { - } + fmulti_ref_trigger_(fmulti_ref_trigger) {} + + ForwardRewriter(const FForwardRewrite* rewrite_func, + std::function fcontext, + std::function fmulti_ref_trigger) + : rewrite_func_(rewrite_func), + fcontext_(fcontext), + fmulti_ref_trigger_(fmulti_ref_trigger) {} + // Transform expression. Expr Rewrite(Expr expr) { @@ -60,8 +67,9 @@ class ForwardRewriter : private ExprMutator { private: // The rewrite rule. - const OpMap& rewrite_map_; - // The context. 
+ const OpMap* rewrite_map_{nullptr}; + const FForwardRewrite* rewrite_func_{nullptr}; + // The context.const std::function fcontext_{nullptr}; // The multiple reference trigger std::function fmulti_ref_trigger_{nullptr}; @@ -104,9 +112,31 @@ class ForwardRewriter : private ExprMutator { } } + Expr VisitExpr_(const TupleNode* op) final { + tvm::Array fields; + bool all_fields_unchanged = true; + for (auto field : op->fields) { + auto new_field = this->GetTempExpr(field); + fields.push_back(new_field); + all_fields_unchanged &= new_field.same_as(field); + } + + if (all_fields_unchanged) { + return GetRef(op); + } else { + return TupleNode::make(fields); + } + } + Expr VisitExpr_(const CallNode* call_node) final { const Call& ref_call = GetRef(call_node); - PackedFunc frewrite = rewrite_map_.get(call_node->op, nullptr); + PackedFunc frewrite; + if (rewrite_func_) { + frewrite = *rewrite_func_; + } else { + CHECK(rewrite_map_); + frewrite = rewrite_map_->get(call_node->op, nullptr); + } auto new_op = this->Mutate(call_node->op); bool unchanged = call_node->op.same_as(new_op); @@ -147,9 +177,16 @@ Expr ForwardRewrite(const Expr& expr, std::function fcontext, std::function fmulti_ref_trigger) { auto rewrite_map = Op::GetAttr(rewrite_map_name); - return ForwardRewriter(rewrite_map, - fcontext, - fmulti_ref_trigger).Rewrite(expr); + return ForwardRewriter(&rewrite_map, fcontext, fmulti_ref_trigger).Rewrite(expr); +} + +Expr ForwardRewrite(const Expr& expr, + const FForwardRewrite& rewrite_func, + std::function fcontext, + std::function fmulti_ref_trigger) { + return ForwardRewriter(&rewrite_func, fcontext, fmulti_ref_trigger).Rewrite(expr); } + + } // namespace relay } // namespace tvm diff --git a/src/relay/pass/pattern_util.h b/src/relay/pass/pattern_util.h index 38ae923c52744..e6e8415bd620f 100644 --- a/src/relay/pass/pattern_util.h +++ b/src/relay/pass/pattern_util.h @@ -73,7 +73,7 @@ inline bool MatchBroadcastToLeftAxes(const TensorTypeNode* tlhs, * the target Tensor 
on the specified axis via broadcasting rule. * * \param bias The bias. - * \param target_ndim target dimension. + * \param target_ndim Target dimension. * \param axes The axis on the output we want to match on. */ inline Expr ExpandBiasToMatchAxis(Expr bias, diff --git a/tests/python/relay/test_pass_alter_op_layout.py b/tests/python/relay/test_pass_alter_op_layout.py new file mode 100644 index 0000000000000..6a8be7ea847eb --- /dev/null +++ b/tests/python/relay/test_pass_alter_op_layout.py @@ -0,0 +1,316 @@ +"""Test alter op layout pass""" + +from tvm import relay +from tvm.relay.op import register_alter_op_layout +from tvm.relay.ir_pass import * + +def test_alter_op(): + """Test directly replacing an operator with a new one""" + def before(): + x = relay.var("x", shape=(1, 64, 56, 56)) + weight = relay.var('weight', shape=(64, 64, 3, 3)) + y = relay.nn.conv2d(x, weight, + channels=64, + kernel_size=(3, 3), + padding=(1, 1)) + y = relay.nn.relu(y) + y = relay.Function([x, weight], y) + return y + + @register_alter_op_layout("nn.conv2d", level=100) + def alter_conv2d(attrs, inputs, tinfos): + data, weight = inputs + weight = relay.multiply(weight, relay.const(2.0)) + return relay.nn.conv2d(data, weight, **attrs) + + def expected(): + x = relay.var("x", shape=(1, 64, 56, 56)) + weight = relay.var('weight', shape=(64, 64, 3, 3)) + y = relay.nn.conv2d(x, relay.multiply(weight, relay.const(2.0)), + channels=64, + kernel_size=(3, 3), + padding=(1, 1)) + y = relay.nn.relu(y) + y = relay.Function([x, weight], y) + return y + + a = before() + a = infer_type(a) + a = alter_op_layout(a) + + b = expected() + b = infer_type(b) + + assert(alpha_equal(a, b)) + + +def test_alter_return_none(): + """Test doing nothing by returning 'None' """ + def before(): + x = relay.var("x", shape=(1, 64, 56, 56)) + y = relay.nn.global_max_pool2d(x) + y = relay.Function([x], y) + return y + + called = [False] + + @register_alter_op_layout("nn.global_max_pool2d", level=101) + def 
alter_conv2d(attrs, inputs, tinfos): + called[0] = True + return None + + a = before() + a = infer_type(a) + a = alter_op_layout(a) + + b = before() + b = infer_type(b) + assert(alpha_equal(a, b)) + assert(called[0]) + + +def test_alter_layout(): + """Test alternating the layout of a conv2d. + The layout of broadcast operators and the weight should be changed accordingly. + """ + def before(): + x = relay.var("x", shape=(1, 64, 56, 56)) + bias = relay.var("bias") + weight = relay.var("weight") + y = relay.nn.conv2d(x, weight, channels=64, kernel_size=(3, 3), padding=(1, 1)) + y = relay.nn.bias_add(y, bias) + # a useless tuple, which will be eliminated + y = relay.Tuple([y])[0] + y = relay.nn.relu(y) + y = relay.nn.batch_flatten(y) + y = relay.Function(free_vars(y), y) + return y + + @register_alter_op_layout("nn.conv2d", level=102) + def alter_conv2d(attrs, inputs, tinfos): + data, weight = inputs + new_attrs = dict(attrs) + new_attrs['data_layout'] = 'NCHW16c' + new_attrs['weight_layout'] = 'OIHW16i' + return relay.nn.conv2d(data, weight, **new_attrs) + + def expected(): + x = relay.var("x", shape=(1, 64, 56, 56)) + bias = relay.var("bias", shape=(64,)) + weight = relay.var("weight", shape=(64, 64, 3, 3)) + + y = relay.layout_transform(x, "NCHW", "NCHW16c") + w = relay.layout_transform(weight, "OIHW", "OIHW16i") + y = relay.nn.conv2d(y, w, + channels=64, + kernel_size=(3, 3), + padding=(1, 1), + weight_layout="OIHW16i", + data_layout="NCHW16c") + b = relay.expand_dims(bias, axis=1, num_newaxis=2) + b = relay.layout_transform(b, "CHW", "CHW16c") + y = relay.add(y, b) + + y = relay.nn.relu(y) + y = relay.layout_transform(y, "NCHW16c", "NCHW") + y = relay.nn.batch_flatten(y) + y = relay.Function(free_vars(y), y) + return y + + a = before() + a = infer_type(a) + a = canonicalize_ops(a) + a = infer_type(a) + a = alter_op_layout(a) + a = infer_type(a) + + b = expected() + b = infer_type(b) + + assert(alpha_equal(a, b)) + + +def test_alter_layout_dual_path(): + """ + 
Test alternating the layout with two outputs. + One path continues to use the new layout while one path falls back to the old layout. + """ + def before(): + x = relay.var("x", shape=(1, 64, 56, 56)) + weight1 = relay.var('weight1') + weight2 = relay.var('weight2') + y = relay.nn.conv2d(x, weight1, + channels=32, + kernel_size=(3, 3), + padding=(1, 1)) + y = relay.nn.relu(y) + y1 = relay.nn.conv2d(y, weight2, + channels=32, + kernel_size=(3, 3), + padding=(1, 1)) + y1 = relay.nn.relu(y1) + y2 = relay.nn.batch_flatten(y) + ret = relay.Tuple([y1, y2]) + y = relay.Function(free_vars(ret), ret) + return y + + @register_alter_op_layout("nn.conv2d", level=103) + def alter_conv2d(attrs, inputs, tinfos): + data, weight = inputs + new_attrs = dict(attrs) + new_attrs['data_layout'] = 'NCHW16c' + return relay.nn.conv2d(data, weight, **new_attrs) + + def expected(): + x = relay.var("x", shape=(1, 64, 56, 56)) + weight1 = relay.var('weight1') + weight2 = relay.var('weight2') + y = relay.layout_transform(x, "NCHW", "NCHW16c") + y = relay.nn.conv2d(y, weight1, + channels=32, + kernel_size=(3, 3), + padding=(1, 1), + data_layout="NCHW16c") + y = relay.nn.relu(y) + y1 = relay.nn.conv2d(y, weight2, + channels=32, + kernel_size=(3, 3), + padding=(1, 1), + data_layout='NCHW16c') + y1 = relay.nn.relu(y1) + y1 = relay.layout_transform(y1, "NCHW16c", "NCHW") + y2 = relay.layout_transform(y, "NCHW16c", "NCHW") + y2 = relay.nn.batch_flatten(y2) + ret = relay.Tuple([y1, y2]) + y = relay.Function(free_vars(ret), ret) + return y + + a = before() + a = infer_type(a) + a = alter_op_layout(a) + a = infer_type(a) + + b = expected() + b = infer_type(b) + + assert(alpha_equal(a, b)) + +def test_alter_layout_resnet(): + """Test alternating the layout of a residual block + This also tests the elimination of duplicated transformation. + If the same transformation applies to the same node twice, only one transformation will be created.
+ """ + def before(): + x = relay.var("x", shape=(1, 64, 56, 56)) + weight1 = relay.var('weight1') + weight2 = relay.var('weight2') + y = relay.nn.conv2d(x, weight1, + channels=32, + kernel_size=(3, 3), + padding=(1, 1)) + y = relay.nn.relu(y) + y2 = relay.nn.conv2d(x, weight2, + channels=32, + kernel_size=(1, 1)) + y2 = relay.nn.relu(y2) + y = y + y2 + y = relay.nn.global_max_pool2d(y) + return relay.Function(free_vars(y), y) + + @register_alter_op_layout("nn.conv2d", level=104) + def alter_conv2d(attrs, inputs, tinfos): + data, weight = inputs + new_attrs = dict(attrs) + new_attrs['data_layout'] = 'NCHW16c' + return relay.nn.conv2d(data, weight, **new_attrs) + + def expected(): + x = relay.var("x", shape=(1, 64, 56, 56)) + weight1 = relay.var('weight1') + weight2 = relay.var('weight2') + x = relay.layout_transform(x, "NCHW", "NCHW16c") + y = relay.nn.conv2d(x, weight1, + channels=32, + kernel_size=(3, 3), + padding=(1, 1), + data_layout="NCHW16c") + y = relay.nn.relu(y) + y2 = relay.nn.conv2d(x, weight2, + channels=32, + kernel_size=(1, 1), + data_layout='NCHW16c') + y2 = relay.nn.relu(y2) + y = y + y2 + y = relay.nn.global_max_pool2d(y, layout="NCHW16c") + y = relay.layout_transform(y, "NCHW16c", "NCHW") + return relay.Function(free_vars(y), y) + + a = before() + a = infer_type(a) + a = alter_op_layout(a) + a = infer_type(a) + + b = expected() + b = infer_type(b) + + assert(alpha_equal(a, b)) + + +def test_alter_layout_broadcast_op(): + """Test boradcast operators """ + def before(): + x = relay.var("x", shape=(1, 64, 56, 56)) + bias = relay.var("bias", shape=(64,)) + scale = relay.var("scale", shape=(64, 1, 1)) + weight = relay.var("weight") + y = relay.nn.conv2d(x, weight, channels=64, kernel_size=(3, 3), padding=(1, 1)) + y = relay.nn.bias_add(y, bias) # test broadcasting to lhs + y = relay.multiply(scale, y) # test broadcasting to rhs + y = relay.Function(free_vars(y), y) + return y + + @register_alter_op_layout("nn.conv2d", level=102) + def 
alter_conv2d(attrs, inputs, tinfos): + data, weight = inputs + new_attrs = dict(attrs) + new_attrs['data_layout'] = 'NCHW16c' + return relay.nn.conv2d(data, weight, **new_attrs) + + def expected(): + x = relay.var("x", shape=(1, 64, 56, 56)) + bias = relay.var("bias", shape=(64,)) + scale = relay.var("scale", shape=(64, 1, 1)) + weight = relay.var("weight") + x = relay.layout_transform(x, "NCHW", "NCHW16c") + bias = relay.expand_dims(bias, 1, 2) + bias = relay.layout_transform(bias, "CHW", "CHW16c") + scale = relay.layout_transform(scale, "CHW", "CHW16c") + y = relay.nn.conv2d(x, weight, channels=64, kernel_size=(3, 3), padding=(1, 1), + data_layout="NCHW16c") + y = relay.add(y, bias) # test broadcasting to lhs + y = relay.multiply(scale, y) # test broadcasting to rhs + y = relay.layout_transform(y, "NCHW16c", "NCHW") + y = relay.Function(free_vars(y), y) + return y + + a = before() + a = infer_type(a) + a = canonicalize_ops(a) + a = infer_type(a) + a = alter_op_layout(a) + a = infer_type(a) + + b = expected() + b = infer_type(b) + + assert(alpha_equal(a, b)) + + +if __name__ == "__main__": + test_alter_op() + test_alter_return_none() + test_alter_layout() + test_alter_layout_dual_path() + test_alter_layout_resnet() + test_alter_layout_broadcast_op() + diff --git a/topi/include/topi/nn.h b/topi/include/topi/nn.h index 5fc05162f09b6..9d3e675d8ef7a 100644 --- a/topi/include/topi/nn.h +++ b/topi/include/topi/nn.h @@ -448,6 +448,7 @@ inline tvm::Tensor group_conv2d_ngchw(const tvm::Tensor& I, } using FLayoutIndicesTransform = std::function(const Array& indices)>; + /*! * \brief Transform the layout according to the mapping function \p to_src_indices. * \param src the source input.