[RELAY][FRONTEND] Initial MXNet frontend support.
tqchen committed Nov 25, 2018
1 parent 9473dca commit fc0c563
Showing 39 changed files with 2,213 additions and 171 deletions.
6 changes: 4 additions & 2 deletions docs/langref/relay_op.rst
@@ -82,6 +82,7 @@ This level enables additional math and transform operators.
tvm.relay.reshape_like
tvm.relay.copy
tvm.relay.transpose
+ tvm.relay.squeeze
tvm.relay.floor
tvm.relay.ceil
tvm.relay.trunc
@@ -114,7 +115,7 @@ This level enables additional math and transform operators.
tvm.relay.less_equal
tvm.relay.maximum
tvm.relay.minimum
- tvm.relay.pow
+ tvm.relay.power
tvm.relay.where
tvm.relay.argmax
tvm.relay.argmin
@@ -196,6 +197,7 @@ Level 3 Definitions
.. autofunction:: tvm.relay.reshape
.. autofunction:: tvm.relay.reshape_like
.. autofunction:: tvm.relay.copy
+ .. autofunction:: tvm.relay.squeeze
.. autofunction:: tvm.relay.transpose
.. autofunction:: tvm.relay.take
.. autofunction:: tvm.relay.zeros
@@ -220,7 +222,7 @@ Level 4 Definitions
.. autofunction:: tvm.relay.less_equal
.. autofunction:: tvm.relay.maximum
.. autofunction:: tvm.relay.minimum
- .. autofunction:: tvm.relay.pow
+ .. autofunction:: tvm.relay.power
.. autofunction:: tvm.relay.where
.. autofunction:: tvm.relay.argmax
.. autofunction:: tvm.relay.argmin
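The docs now track the pow-to-power rename and the newly exposed squeeze operator. A minimal sketch of the two at the Python level (shapes and constants here are illustrative, not from the commit):

    from tvm import relay

    x = relay.var("x", shape=(1, 3, 1, 2))
    # squeeze drops the size-1 axes, giving shape (3, 2).
    y = relay.squeeze(x)
    # power is the documented name for the former pow operator.
    z = relay.power(y, relay.const(2.0))
    print(relay.Function([x], z))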
2 changes: 1 addition & 1 deletion include/tvm/relay/attrs/nn.h
@@ -89,7 +89,7 @@ struct SoftmaxAttrs : public tvm::AttrsNode<SoftmaxAttrs> {
int axis;

TVM_DECLARE_ATTRS(SoftmaxAttrs, "relay.attrs.SoftmaxAttrs") {
- TVM_ATTR_FIELD(axis).set_default(1)
+ TVM_ATTR_FIELD(axis).set_default(-1)
.describe("The axis to sum over when computing softmax.");
}
};
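With the default flipped from 1 to -1, softmax now normalizes over the last axis unless told otherwise. A small sketch of the effect from Python:

    from tvm import relay

    x = relay.var("x", shape=(2, 3, 10))
    # Equivalent to relay.nn.softmax(x, axis=-1) under the new default;
    # the old default would have normalized over axis 1 (size 3) instead.
    out = relay.nn.softmax(x)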
2 changes: 1 addition & 1 deletion include/tvm/relay/attrs/transform.h
@@ -62,7 +62,7 @@ struct TransposeAttrs : public tvm::AttrsNode<TransposeAttrs> {

/*! \brief Attributes used in reshape operators */
struct ReshapeAttrs : public tvm::AttrsNode<ReshapeAttrs> {
- Array<IndexExpr> newshape;
+ Array<Integer> newshape;
TVM_DECLARE_ATTRS(ReshapeAttrs, "relay.attrs.ReshapeAttrs") {
TVM_ATTR_FIELD(newshape)
.describe("The new shape. Should be compatible with the original shape.");
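Switching newshape from IndexExpr to Integer pins the attribute to concrete integers. From the Python API the call looks the same; a sketch:

    from tvm import relay

    x = relay.var("x", shape=(2, 3, 4))
    # newshape is now carried as plain integers in ReshapeAttrs.
    y = relay.reshape(x, newshape=(6, 4))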
4 changes: 2 additions & 2 deletions nnvm/src/top/tensor/transform.cc
@@ -420,9 +420,9 @@ along which to split the array.
return Array<Tensor>{
topi::split_sections(inputs[0], param.indices_or_sections[0], param.axis) };
} else {
- Array<Expr> indices;
+ Array<Integer> indices;
for (auto i : param.indices_or_sections) {
- indices.push_back(tvm::make_const(tvm::Int(32), i));
+ indices.push_back(static_cast<int>(i));
}
return Array<Tensor>{ topi::split(inputs[0], indices, param.axis) };
}
4 changes: 3 additions & 1 deletion python/tvm/relay/__init__.py
@@ -7,7 +7,7 @@
from . import expr
from . import module
from . import ir_pass
- from .build_module import build, create_executor
+ from .build_module import build, build_config, create_executor

# Root operators
from .op import Op
@@ -17,6 +17,7 @@
from . import nn
from . import vision
from . import image
+ from . import frontend
from . import backend

from .scope_builder import ScopeBuilder
@@ -40,6 +41,7 @@
scalar_type = ty.scalar_type

# Expr
+ Expr = expr.Expr
Constant = expr.Constant
Tuple = expr.Tuple
Var = expr.Var
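The package now re-exports build_config alongside build, and exposes the frontend module and the Expr base class. A hedged usage sketch (the relu body is just a placeholder):

    from tvm import relay

    x = relay.var("x", shape=(1, 4))
    func = relay.Function([x], relay.nn.relu(x))

    # build_config gates which passes in OPT_PASS_LEVEL run (see
    # build_module.py below); opt_level=2 enables FoldConstant.
    with relay.build_config(opt_level=2):
        graph, lib, params = relay.build(func, target="llvm")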
14 changes: 12 additions & 2 deletions python/tvm/relay/backend/compile_engine.py
@@ -72,8 +72,18 @@ def lower(self, source_func, target=None):
cached_func: CachedFunc
The result of lowering.
"""
-        key = _get_cache_key(source_func, target)
-        return _backend._CompileEngineLower(self, key)
+        # pylint: disable=broad-except
+        try:
+            key = _get_cache_key(source_func, target)
+            return _backend._CompileEngineLower(self, key)
+        except Exception:
+            import traceback
+            msg = traceback.format_exc()
+            msg += "Error during compile func\n"
+            msg += "--------------------------\n"
+            msg += source_func.astext(show_meta_data=False)
+            msg += "--------------------------\n"
+            raise RuntimeError(msg)

def jit(self, source_func, target=None):
"""JIT a source_func to a tvm.Function.
2 changes: 1 addition & 1 deletion python/tvm/relay/backend/graph_runtime_codegen.py
@@ -357,4 +357,4 @@ def _get_unique_name(self, name):
return name
index = self._name_map[name]
self._name_map[name] += 1
- return self.get_unique_name(name + str(index))
+ return self._get_unique_name(name + str(index))
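The one-character fix matters because the recursive call previously named a method that does not exist (get_unique_name instead of _get_unique_name), so any name collision raised AttributeError. A standalone sketch of the uniquing scheme being fixed:

    def get_unique_name(name_map, name):
        """Return name on first use; append a running index on collisions."""
        if name not in name_map:
            name_map[name] = 1
            return name
        index = name_map[name]
        name_map[name] += 1
        # Recurse in case name + str(index) is itself already taken.
        return get_unique_name(name_map, name + str(index))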
3 changes: 1 addition & 2 deletions python/tvm/relay/build_module.py
@@ -13,7 +13,7 @@
# Optimization passes and the opt level at which each is enabled
OPT_PASS_LEVEL = {
"SimplifyInference": 0,
"CombineParallelConv2D": 1,
"CombineParallelConv2D": 4,
"OpFusion": 1,
"FoldConstant": 2,
"FoldScaleAxis": 3,
@@ -157,7 +157,6 @@ def optimize(func, params=None):

if cfg.pass_enabled("FoldConstant"):
func = ir_pass.fold_constant(func)

return func


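FoldConstant stays at level 2 while CombineParallelConv2D moves behind level 4. The pass can also be invoked directly on a function; a minimal sketch:

    from tvm import relay
    from tvm.relay import ir_pass

    x = relay.var("x", shape=(2,))
    c = relay.add(relay.const(1.0), relay.const(2.0))
    func = relay.Function([x], relay.add(x, c))
    # fold_constant evaluates 1.0 + 2.0 at compile time.
    func = ir_pass.fold_constant(func)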
4 changes: 4 additions & 0 deletions python/tvm/relay/frontend/__init__.py
@@ -0,0 +1,4 @@
"""Relay frontends."""
from __future__ import absolute_import

from .mxnet import from_mxnet
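A hedged end-to-end sketch of the new entry point, assuming it follows the NNVM-style signature of a symbol plus a shape dict and returns a (function, params) pair; layer names and shapes below are illustrative:

    import mxnet as mx
    from tvm import relay

    # A toy MXNet graph.
    data = mx.sym.var("data")
    net = mx.sym.FullyConnected(data, num_hidden=10, name="fc1")
    net = mx.sym.softmax(net, name="sm")

    func, params = relay.frontend.from_mxnet(net, shape={"data": (1, 784)})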
129 changes: 129 additions & 0 deletions python/tvm/relay/frontend/common.py
@@ -0,0 +1,129 @@
"""Common utilities"""
from __future__ import absolute_import as _abs


class RequiredAttr(object):
"""Dummpy class to represent required attr"""
pass


class StrAttrsDict(object):
"""Helper class to parse attrs stored as Dict[str, str].
Parameters
----------
attrs : Dict[str, str]
The attributes to be used.
"""
def __init__(self, attrs):
self.attrs = attrs

def get_float(self, key, default=RequiredAttr()):
"""Get float attribute
Parameters
----------
key : str
The attribute key
default : float
The default value.
Returns
-------
value : The result
"""
if key in self.attrs:
return float(self.attrs[key])
if isinstance(default, RequiredAttr):
raise AttributeError("Required attribute {} not found.".format(key))
return default

def get_int(self, key, default=RequiredAttr()):
"""Get int attribute
Parameters
----------
key : str
The attribute key
default : int
The default value.
Returns
-------
value : The result
"""
if key in self.attrs:
val = self.attrs[key]
if val == "None":
return None
return int(val)
if isinstance(default, RequiredAttr):
raise AttributeError("Required attribute {} not found.".format(key))
return default

def get_str(self, key, default=RequiredAttr()):
"""Get str attribute
Parameters
----------
key : str
The attribute key
default : str
The default value.
Returns
-------
value : The result
"""
if key in self.attrs:
return self.attrs[key]
if isinstance(default, RequiredAttr):
raise AttributeError("Required attribute {} not found.".format(key))
return default

def get_int_tuple(self, key, default=RequiredAttr()):
"""Get int tuple attribute
Parameters
----------
key : str
The attribute key
default : tuple of int
The default value.
Returns
-------
value : The result
"""
if key in self.attrs:
tshape = self.attrs[key]
return tuple(int(x.strip()) for x in tshape.strip('()').split(','))
if isinstance(default, RequiredAttr):
raise AttributeError("Required attribute {} not found.".format(key))
return default

def get_bool(self, key, default=RequiredAttr()):
"""Get bool tuple attribute
Parameters
----------
key : str
The attribute key
default : bool
The default value.
Returns
-------
value : The result
"""
if key in self.attrs:
val = self.attrs[key]
return val.strip().lower() in ['true', '1', 't', 'y', 'yes']
if isinstance(default, RequiredAttr):
raise AttributeError("Required attribute {} not found.".format(key))
return default
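A short usage sketch of the helper above, built directly from the getters it defines:

    from tvm.relay.frontend.common import StrAttrsDict

    attrs = StrAttrsDict({"axis": "1", "kernel": "(3, 3)", "use_bias": "True"})
    attrs.get_int("axis")             # -> 1
    attrs.get_int_tuple("kernel")     # -> (3, 3)
    attrs.get_bool("use_bias")        # -> True
    attrs.get_float("epsilon", 1e-5)  # -> 1e-05 (falls back to the default)
    # attrs.get_str("name")           # would raise AttributeError (required)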