[Relay] Improve more operator mxnet frontend importer (apache#2772)
oovm authored and wweic committed Mar 12, 2019
1 parent 0128af8 commit cc12f7d
Showing 1 changed file with 64 additions and 8 deletions.
python/tvm/relay/frontend/mxnet.py (72 changes: 64 additions & 8 deletions)
@@ -298,6 +298,51 @@ def _mx_leaky_relu(inputs, attrs):
    raise RuntimeError("act_type: {} is not supported".format(act_type))


def _mx_make_power(power):
    def _impl(inputs, _):  # Note: no attrs
        assert len(inputs) == 1
        scalar = _expr.const(power, dtype=None)
        # Note: int maps to "int32", float maps to "float32"
        return _op.power(inputs[0], scalar)
    return _impl
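# Worked example, as a quick sanity check: _mx_make_power(-1/2) gives the
# "rsqrt" converter, and rsqrt(4) = 4 ** -0.5 = 0.5.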


def _mx_make_exponent(base):
    # exp(b, x) = b^x = e^(x * log(b))
    def _impl(inputs, _):  # Note: no attrs
        assert len(inputs) == 1
        scalar = _op.log(_expr.const(base, dtype="float32"))
        return _op.exp(_op.multiply(inputs[0], scalar))
    return _impl
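# Worked example: _mx_make_exponent(2) gives a 2 ** x converter, since
# e^(3 * log(2)) = e^2.0794... = 8.0 = 2 ** 3.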


def _mx_make_logarithm(base):
    # log(b, x) = log(x) / log(b)
    def _impl(inputs, _):  # Note: no attrs
        assert len(inputs) == 1
        scalar = _op.log(_expr.const(base, dtype="float32"))
        return _op.divide(_op.log(inputs[0]), scalar)
    return _impl
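# Worked example: _mx_make_logarithm(2) gives a log2 converter, since
# log(8) / log(2) = 2.0794... / 0.6931... = 3.0 = log2(8).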


def _mx_expm1():
    # exp_minus_1 x = exp(x) - 1
    def _impl(inputs, _):  # Note: no attrs
        assert len(inputs) == 1
        one = _expr.const(1, dtype="float32")
        return _op.subtract(_op.exp(inputs[0]), one)
    return _impl
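# Worked example: expm1(1.0) = e - 1 = 1.7182...; this lowering computes
# exp(x) - 1 literally rather than relying on a fused expm1 primitive.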


def _mx_log1p():
    # 1_plus_log x = log(x + 1)
    def _impl(inputs, _):  # Note: no attrs
        assert len(inputs) == 1
        one = _expr.const(1, dtype="float32")
        return _op.log(_op.add(inputs[0], one))
    return _impl
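# Worked example: log1p(e - 1) = log(e) = 1.0, the inverse of expm1 above.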


def _mx_lrn(inputs, attrs):
    new_attrs = {}
    new_attrs["alpha"] = attrs.get_float("alpha", 0.0001)
@@ -450,7 +495,6 @@ def _mx_l2_normalize(inputs, attrs):
"exp",
"sigmoid",
"tanh",
"exp",
"negative",
"reshape_like",
"zeros_like",
@@ -482,6 +526,20 @@ def _mx_l2_normalize(inputs, attrs):
"_minimum" : _rename(_op.minimum),
"flatten" : _rename(_op.nn.batch_flatten),
"Flatten" : _rename(_op.nn.batch_flatten),
# scalar power
"square" : _mx_make_power(2),
"sqrt" : _mx_make_power(1/2),
"rsqrt" : _mx_make_power(-1/2),
"cbrt" : _mx_make_power(1/3),
"rcbrt" : _mx_make_power(-1/3),
"__pow_scalar__" : _binop_scalar(_op.power),
"_power_scalar" : _binop_scalar(_op.power),
"__rsub_scalar__" : _rbinop_scalar(_op.subtract),
"_rminus_scalar" : _rbinop_scalar(_op.subtract),
"__rdiv_scalar__" : _rbinop_scalar(_op.divide),
"_rdiv_scalar" : _rbinop_scalar(_op.divide),
"__rpow_scalar__" : _rbinop_scalar(_op.power),
# scalar op
"__add_scalar__" : _binop_scalar(_op.add),
"_plus_scalar" : _binop_scalar(_op.add),
"__sub_scalar__" : _binop_scalar(_op.subtract),
@@ -490,13 +548,10 @@ def _mx_l2_normalize(inputs, attrs):
"_mul_scalar" : _binop_scalar(_op.multiply),
"__div_scalar__" : _binop_scalar(_op.divide),
"_div_scalar" : _binop_scalar(_op.divide),
"__pow_scalar__" : _binop_scalar(_op.power),
"_power_scalar" : _binop_scalar(_op.power),
"__rsub_scalar__" : _rbinop_scalar(_op.subtract),
"_rminus_scalar" : _rbinop_scalar(_op.subtract),
"__rdiv_scalar__" : _rbinop_scalar(_op.divide),
"_rdiv_scalar" : _rbinop_scalar(_op.divide),
"__rpow_scalar__" : _rbinop_scalar(_op.power),
"log2" : _mx_make_logarithm(2),
"log10" : _mx_make_logarithm(10),
"log1p" : _mx_log1p,
"expm1" : _mx_expm1,
"_equal_scalar" : _mx_compare(_op.equal, _binop_scalar),
"_not_equal_scalar" : _mx_compare(_op.not_equal, _binop_scalar),
"_greater_scalar" : _mx_compare(_op.greater, _binop_scalar),
@@ -506,6 +561,7 @@ def _mx_l2_normalize(inputs, attrs):
"_maximum_scalar" : _binop_scalar(_op.maximum),
"_minimum_scalar" : _binop_scalar(_op.minimum),
# reduction ops
"mean" : _reduce(_op.mean),
"max" : _reduce(_op.max),
"min" : _reduce(_op.min),
"sum" : _reduce(_op.sum),

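As a minimal sketch of how the new mappings are exercised end to end, assuming an MXNet 1.x install and a TVM build that includes this commit (the exact return value of from_mxnet has varied across TVM releases):

import mxnet as mx
from tvm import relay

# Build a tiny MXNet graph that hits the newly mapped "sqrt" and "log2" ops.
data = mx.sym.Variable("data")
net = mx.sym.log2(mx.sym.sqrt(data))

# Import through the frontend; _convert_map routes each MXNet op name
# to one of the Relay converters defined above.
func, params = relay.frontend.from_mxnet(net, shape={"data": (1, 4)})
print(func)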