[Relay] Alter Op Layout (#2150)
* [RELAY] Finish alter op pass

* [RELAY] AlterOpLayout Pass

* fix broadcast operators

* fix broadcast operators

* fix broadcast operators

* Support concatenate

* address comments

* address comments

* add comments

* rebase
merrymercy authored Nov 30, 2018
1 parent 4bf1fd8 commit 2a5656b
Showing 35 changed files with 1,498 additions and 217 deletions.
2 changes: 1 addition & 1 deletion 3rdparty/HalideIR
5 changes: 5 additions & 0 deletions include/tvm/relay/attrs/nn.h
@@ -105,6 +105,7 @@ struct Conv2DTransposeAttrs : public tvm::AttrsNode<Conv2DTransposeAttrs> {
  int groups;
  std::string data_layout;
  std::string weight_layout;
  std::string out_layout;
  DataType out_dtype;

  TVM_DECLARE_ATTRS(Conv2DTransposeAttrs, "relay.attrs.Conv2DTransposeAttrs") {
@@ -139,6 +140,10 @@ struct Conv2DTransposeAttrs : public tvm::AttrsNode<Conv2DTransposeAttrs> {
.describe("Dimension ordering of data and weight. Can be 'OIHW', 'OIHW16o16i', etc."
"'O', 'I', 'H', 'W' stands for num_filter, input_channel, height, and width"
"dimensions respectively.");
TVM_ATTR_FIELD(out_layout).set_default("")
.describe("Dimension ordering of output. Can be 'NCHW', 'NHWC', etc."
"'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
"dimensions respectively. Default to be same as input layout.");
TVM_ATTR_FIELD(out_dtype)
.set_default(NullValue<DataType>())
.describe("Output data type, set to explicit type under mixed precision setting");
13 changes: 13 additions & 0 deletions include/tvm/relay/attrs/transform.h
@@ -164,6 +164,19 @@ struct ClipAttrs : public tvm::AttrsNode<ClipAttrs> {
  }
};


struct LayoutTransformAttrs : public tvm::AttrsNode<LayoutTransformAttrs> {
  std::string src_layout;
  std::string dst_layout;

  TVM_DECLARE_ATTRS(LayoutTransformAttrs, "relay.attrs.LayoutTransformAttrs") {
    TVM_ATTR_FIELD(src_layout)
        .describe("The source layout of the tensor. (e.g. NCHW)");
    TVM_ATTR_FIELD(dst_layout)
        .describe("The destination layout of the tensor. (e.g. NCHW16c)");
  }
};

} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_ATTRS_TRANSFORM_H_
2 changes: 1 addition & 1 deletion include/tvm/relay/expr.h
@@ -459,7 +459,7 @@ inline const TTypeNode* ExprNode::type_as() const {
  static_assert(std::is_base_of<TypeNode, TTypeNode>::value,
                "TType must be a special case of type");
  CHECK(checked_type_.defined())
-      << "Type inference for this Expr has not completed";
+      << "Type inference for this Expr has not completed. Try to call infer_type pass.";
  const TTypeNode* node = checked_type_.as<TTypeNode>();
  CHECK(node != nullptr)
      << "Expected type to be " << TTypeNode::_type_key
15 changes: 15 additions & 0 deletions include/tvm/relay/op_attr_types.h
@@ -86,6 +86,21 @@ using FTVMSchedule = runtime::TypedPackedFunc<
             const Array<Tensor>& outs,
             const Target& target)>;

/*!
 * \brief Alter the layout of operators or replace the
 *  operator with other expressions. This function is invoked
 *  by the AlterOpLayout pass.
 * \param attrs The attributes of the original node.
 * \param args The input expressions of the original node.
 * \param tinfos An array of placeholders, used for getting the inferred shapes
 *  and dtypes of the inputs.
 * \return new_expr The modified expression.
 */
using FTVMAlterOpLayout = runtime::TypedPackedFunc<
    Expr(const Attrs& attrs,
         const Array<Expr>& args,
         const Array<Tensor>& tinfos)>;
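
A minimal Python-side sketch of such a hook (hypothetical; alter_conv2d, the level, and the NCHW16c choice are illustrative, not part of this commit): the pass calls the hook with the node's attrs, its argument expressions, and placeholders carrying the inferred input shapes and dtypes; returning None keeps the original operator.

from tvm import relay
from tvm.relay.op import register_alter_op_layout

@register_alter_op_layout("nn.conv2d", level=101)
def alter_conv2d(attrs, inputs, tinfos):
    data, weight = inputs
    # tinfos[0] is a placeholder with the inferred shape/dtype of `data`;
    # this sketch assumes a static NCHW shape.
    if tinfos[0].shape[1].value % 16 != 0:   # channels not divisible by 16
        return None                          # None keeps the original conv2d
    new_attrs = {k: attrs[k] for k in attrs.keys()}
    new_attrs["data_layout"] = "NCHW16c"     # request a tiled layout
    return relay.nn.conv2d(data, weight, **new_attrs)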

/*!
* \brief Forward rewriting rule for a specific op.
*
Expand Down
16 changes: 16 additions & 0 deletions include/tvm/relay/pass.h
@@ -8,6 +8,7 @@

#include <tvm/relay/module.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/op_attr_types.h>
#include <string>

namespace tvm {
@@ -173,6 +174,21 @@ Expr ForwardRewrite(const Expr& expr,
                    std::function<NodeRef(const Call&)> fcontext = nullptr,
                    std::function<Expr(const Expr&)> fmulti_ref_trigger = nullptr);

/*!
 * \brief Apply rewrite rules to rewrite the expr in post DFS order.
 * \param expr The expression.
 * \param rewrite_func The rewrite function that is applied to every operator.
 * \param fcontext Additional callback to provide a context argument for each call node.
 * \param fmulti_ref_trigger Transformation function to be called when
 *  an Expr is consumed by multiple callers.
 * \return The rewritten expression.
 */
Expr ForwardRewrite(const Expr& expr,
                    const FForwardRewrite& rewrite_func,
                    std::function<NodeRef(const Call&)> fcontext = nullptr,
                    std::function<Expr(const Expr&)> fmulti_ref_trigger = nullptr);


/*! \brief A hashing structure in the style of std::hash. */
struct StructuralHash {
/*! \brief Hash a Relay type.
1 change: 1 addition & 0 deletions python/tvm/__init__.py
@@ -13,6 +13,7 @@
from . import schedule
from . import module
from . import node
from . import attrs
from . import ir_builder
from . import target
from . import generic
40 changes: 40 additions & 0 deletions python/tvm/attrs.py
@@ -0,0 +1,40 @@
""" TVM Attribute module, which is mainly used for defining attributes of operators"""
from ._ffi.node import NodeBase, register_node as _register_tvm_node
from ._ffi.function import _init_api
from . import _api_internal


@_register_tvm_node
class Attrs(NodeBase):
"""Attribute node, which is mainly use for defining attributes of relay operators.
Used by function registered in python side, such as compute, schedule and alter_layout.
Attrs is passed as the first argument to these functions.
"""
def list_field_info(self):
""" Get fields information
Returns
-------
infos: list of AttrFieldInfo
List of field information
"""
return _api_internal._AttrsListFieldInfo(self)

def keys(self):
"""Get list of names in the attribute.
Returns
-------
keys : list of str
List of keys
"""
fields = self.list_field_info()
for field in fields:
yield field.name

def __getitem__(self, item):
return self.__getattr__(item)


_init_api("tvm.attrs")
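
Together, keys() and __getitem__ give Attrs a read-only mapping interface. A small usage sketch (the attrs argument stands for any operator's attribute node):

def attrs_to_dict(attrs):
    """Copy an Attrs node into a plain dict via keys()/__getitem__."""
    return {key: attrs[key] for key in attrs.keys()}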
14 changes: 14 additions & 0 deletions python/tvm/relay/base.py
@@ -21,6 +21,20 @@ def register_relay_node(type_key=None):
    return _register_tvm_node(type_key)


def register_relay_attr_node(type_key=None):
    """Register a Relay attribute node.

    Parameters
    ----------
    type_key : str or cls
        The type key of the node.
    """
    if not isinstance(type_key, str):
        return _register_tvm_node(
            "relay.attrs." + type_key.__name__)(type_key)
    return _register_tvm_node(type_key)


class RelayNode(NodeBase):
    """Base class of all Relay nodes."""
    def astext(self, show_meta_data=True, annotate=None):
8 changes: 8 additions & 0 deletions python/tvm/relay/build_module.py
@@ -17,6 +17,7 @@
"FoldConstant": 2,
"CombineParallelConv2D": 3,
"FoldScaleAxis": 3,
"AlterOpLayout": 3,
}

class BuildConfig(object):
@@ -157,6 +158,13 @@ def optimize(func, params=None):

    if cfg.pass_enabled("FoldConstant"):
        func = ir_pass.fold_constant(func)

    if cfg.pass_enabled("AlterOpLayout"):
        func = ir_pass.infer_type(func)
        func = ir_pass.canonicalize_ops(func)
        func = ir_pass.infer_type(func)
        func = ir_pass.alter_op_layout(func)

    return func
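
Because "AlterOpLayout" is registered at level 3 in OPT_PASS_LEVEL, the pass only runs when the build config's opt_level is at least 3. A sketch (assuming func, params, and an LLVM target are already defined):

from tvm import relay

with relay.build_config(opt_level=3):    # opt_level >= 3 enables AlterOpLayout
    graph, lib, params = relay.build(func, target="llvm", params=params)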


36 changes: 36 additions & 0 deletions python/tvm/relay/ir_pass.py
@@ -191,6 +191,23 @@ def simplify_inference(expr):
    return _ir_pass.simplify_inference(expr)


def canonicalize_ops(expr):
    """Canonicalize special operators into basic operators.
    This can simplify subsequent analysis, e.g., bias_add is expanded into
    expand_dims followed by broadcast add.

    Parameters
    ----------
    expr : tvm.relay.Expr
        The input expression.

    Returns
    -------
    result : tvm.relay.Expr
        An expression without bias_add.
    """
    return _ir_pass.canonicalize_ops(expr)
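
A sketch of the bias_add expansion mentioned above (the shapes are illustrative):

from tvm import relay

x = relay.var("x", shape=(1, 64, 56, 56))
b = relay.var("b", shape=(64,))
y = relay.nn.bias_add(x, b)              # the special operator
y = relay.ir_pass.infer_type(y)          # canonicalization needs type info
y = relay.ir_pass.canonicalize_ops(y)    # expands into expand_dims + broadcast add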


def dead_code_elimination(expr):
    """Remove expressions that do not affect the program result (dead code).
@@ -321,3 +338,22 @@ def combine_parallel_conv2d(expr):
        Transformed expression.
    """
    return _ir_pass.CombineParallelConv2D(expr)


def alter_op_layout(expr):
    """Alter the layouts of operators or replace primitive operators with
    other expressions. This pass can be used to compute convolutions in
    custom layouts or to perform other general weight pre-transformations.

    Parameters
    ----------
    expr : tvm.relay.Expr
        The input expression.

    Returns
    -------
    transformed_expr : tvm.relay.Expr
        Transformed expression with altered layouts.
    """
    return _ir_pass.AlterOpLayout(expr)
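
As in build_module.optimize above, the pass expects typed, canonicalized input, so a typical invocation (a sketch mirroring that pipeline) is:

func = relay.ir_pass.infer_type(func)
func = relay.ir_pass.canonicalize_ops(func)
func = relay.ir_pass.infer_type(func)      # re-infer types after canonicalization
func = relay.ir_pass.alter_op_layout(func)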
4 changes: 3 additions & 1 deletion python/tvm/relay/op/__init__.py
@@ -1,7 +1,8 @@
#pylint: disable=wildcard-import, redefined-builtin
"""Relay core operators."""
# operator defs
-from .op import get, register, register_schedule, register_compute, Op
+from .op import get, register, register_schedule, register_compute, register_alter_op_layout, \
+    Op

# Operators
from .reduce import *
@@ -10,6 +11,7 @@
from . import nn
from . import image
from . import vision
from . import op_attrs

# operator registry
from . import _tensor
9 changes: 0 additions & 9 deletions python/tvm/relay/op/_tensor.py
@@ -80,12 +80,3 @@ def clip_compute(attrs, inputs, output_type, target):
    return [topi.clip(inputs[0], attrs.a_min, attrs.a_max)]

register_schedule("clip", schedule_elemwise)
register_pattern("clip", OpPattern.ELEMWISE)

# concatenate
@register_compute("concatenate")
def concatenate_compute(attrs, inputs, output_type, target):
return [topi.concatenate(inputs, axis=attrs.axis)]

register_schedule("concatenate", schedule_injective)
register_pattern("concatenate", OpPattern.INJECTIVE)
18 changes: 16 additions & 2 deletions python/tvm/relay/op/_transform.py
@@ -1,8 +1,10 @@
"""Backend compiler related feature registration"""
# pylint: disable=invalid-name
# pylint: disable=invalid-name,unused-argument
from __future__ import absolute_import
import topi
from . import op as _reg
from ._reduce import _schedule_reduce
from .op import schedule_injective, OpPattern

schedule_injective = _reg.schedule_injective
schedule_broadcast = _reg.schedule_injective
@@ -15,10 +17,22 @@
_reg.register_schedule("reshape_like", schedule_injective)
_reg.register_schedule("full", schedule_injective)
_reg.register_schedule("full_like", schedule_injective)
_reg.register_schedule("cast", schedule_broadcast)
_reg.register_schedule("cast", schedule_injective)
_reg.register_schedule("strided_slice", schedule_injective)
_reg.register_schedule("slice_like", schedule_injective)
_reg.register_schedule("split", schedule_injective)
_reg.register_schedule("take", schedule_injective)
_reg.register_schedule("transpose", schedule_injective)
_reg.register_schedule("where", schedule_broadcast)

# layout_transform
_reg.register_schedule("layout_transform", schedule_injective)
_reg.register_pattern("layout_transform", OpPattern.INJECTIVE)

# concatenate
@_reg.register_compute("concatenate")
def concatenate_compute(attrs, inputs, output_type, target):
    return [topi.concatenate(inputs, axis=attrs.axis)]

_reg.register_schedule("concatenate", schedule_injective)
_reg.register_pattern("concatenate", OpPattern.INJECTIVE)
22 changes: 20 additions & 2 deletions python/tvm/relay/op/op.py
@@ -107,7 +107,7 @@ def register_schedule(op_name, schedule=None, level=10):
    op_name : str
        The name of the op.

-    schedule : function
+    schedule : function (attrs: Attrs, outs: List[Tensor], target: Target) -> sch: Schedule
        The schedule function.

    level : int
@@ -124,7 +124,8 @@ def register_compute(op_name, compute=None, level=10):
    op_name : str
        The name of the op.

-    compute : function
+    compute : function (attrs: Attrs, inputs: List[Tensor], out_type: Type, target: Target)
+        -> List[Tensor]
        The compute function.

    level : int
@@ -133,6 +134,23 @@ def register_compute(op_name, compute=None, level=10):
return register(op_name, "FTVMCompute", compute, level)


def register_alter_op_layout(op_name, alter_layout=None, level=10):
    """Register an alter-op-layout function for an op.

    Parameters
    ----------
    op_name : str
        The name of the operator.

    alter_layout : function (attrs: Attrs, inputs: List[Expr], tinfos: List[Tensor]) -> Expr
        The function for changing the layout or replacing the operator.

    level : int
        The priority level.
    """
    return register(op_name, "FTVMAlterOpLayout", alter_layout, level)


def register_pattern(op_name, pattern, level=10):
"""Register operator pattern for an op.
14 changes: 14 additions & 0 deletions python/tvm/relay/op/op_attrs.py
@@ -0,0 +1,14 @@
"""The attributes node used for Relay operators"""

from ...attrs import Attrs
from ..base import register_relay_attr_node

@register_relay_attr_node
class Conv2DAttrs(Attrs):
"""Attribute of a Convolution Operator"""
pass

@register_relay_attr_node
class GlobalPool2DAttrs(Attrs):
"""Attribute of a Global 2D Pooling Operator"""
pass
22 changes: 22 additions & 0 deletions python/tvm/relay/op/transform.py
@@ -387,3 +387,25 @@ def slice_like(data, shape_like, axes=None):
        The computed result.
    """
    return _make.slice_like(data, shape_like, axes)


def layout_transform(data, src_layout, dst_layout):
    """Transform the layout of a tensor.

    Parameters
    ----------
    data : relay.Expr
        The source tensor to be transformed.

    src_layout : str
        The source layout (e.g. NCHW).

    dst_layout : str
        The destination layout (e.g. NCHW16c).

    Returns
    -------
    ret : relay.Expr
        The transformed tensor.
    """
    return _make.layout_transform(data, src_layout, dst_layout)
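
A short usage sketch (hypothetical shapes): tiling the channel dimension of an NCHW tensor by 16.

from tvm import relay

x = relay.var("x", shape=(1, 64, 56, 56))
y = relay.layout_transform(x, "NCHW", "NCHW16c")   # result shape (1, 4, 56, 56, 16)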