Skip to content

Commit

Permalink
add sanity check in relay and mxnet frontend mapping
Browse files Browse the repository at this point in the history
  • Loading branch information
icemelon committed Feb 19, 2019
1 parent 7b54202 commit 91eec2b
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 6 deletions.
11 changes: 11 additions & 0 deletions python/tvm/relay/frontend/mxnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,16 @@ def _mx_multibox_detection(inputs, attrs):
return _op.vision.nms(ret[0], ret[1], **new_attrs1)


def _mx_arange(inputs, attrs):
if attrs.get_int("repeat", 1) != 1:
raise RuntimeError("arange doesn't support repeat")
start = attrs.get_float("start", 0)
stop = attrs.get_float("stop")
step = attrs.get_float("step", 1)
dtype = attrs.get_str("dtype", "float32")
return _op.arange(start, stop, step, dtype)


# Note: due to attribute conversion constraint
# ops in the identity set must be attribute free
_identity_list = [
Expand Down Expand Up @@ -352,6 +362,7 @@ def _mx_multibox_detection(inputs, attrs):
"Concat" : _mx_concat,
"concat" : _mx_concat,
"LeakyReLU" : _mx_leaky_relu,
"_arange" : _mx_arange,
"SoftmaxOutput" : _mx_softmax_output,
"SoftmaxActivation" : _mx_softmax_activation,
# vision
Expand Down
6 changes: 3 additions & 3 deletions python/tvm/relay/op/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,13 +257,13 @@ def arange(stop, start=None, step=1, dtype="float32"):
Parameters
----------
stop : tvm.Expr
Stop of interval. The interval does not include this value.
start : tvm.Expr, optional
Start of interval. The interval includes this value. The default start
value is 0.
stop : tvm.Expr
Stop of interval. The interval does not include this value.
step : tvm.Expr, optional
Spacing between values. The default step size is 1.
Expand Down
3 changes: 3 additions & 0 deletions src/relay/op/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -891,6 +891,9 @@ bool ArangeRel(const Array<Type>& types,
const ArangeAttrs* param = attrs.as<ArangeAttrs>();
IndexExpr num_elem = tvm::cast(tvm::Int(32), tvm::ceil(
tvm::cast(tvm::Float(32), param->stop - param->start) / param->step));
if (const tvm::ir::IntImm* val = num_elem.as<tvm::ir::IntImm>()) {
CHECK_GE(val->value, 0) << "Invalid arange inputs";
}
reporter->Assign(types[0], TensorTypeNode::make({num_elem}, param->dtype));
return true;
}
Expand Down
6 changes: 3 additions & 3 deletions topi/python/topi/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,13 +296,13 @@ def arange(stop, start=None, step=1, dtype="float32"):
Parameters
----------
stop : tvm.Expr
Stop of interval. The interval does not include this value.
start : tvm.Expr, optional
Start of interval. The interval includes this value. The default start
value is 0.
stop : tvm.Expr
Stop of interval. The interval does not include this value.
step : tvm.Expr, optional
Spacing between values. The default step size is 1.
Expand Down

0 comments on commit 91eec2b

Please sign in to comment.