-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Relay/TOPI][OP] Add arange op in Relay and TOPI #2621
Conversation
include/tvm/relay/attrs/transform.h
Outdated
.describe("Start of interval. The interval includes this value."); | ||
TVM_ATTR_FIELD(stop) | ||
.describe("Stop of interval. The interval does not include this value."); | ||
TVM_ATTR_FIELD(start).set_default(make_const(Int(32), 1)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
s/start/step
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed.
relay.arange(5) = [0, 1, 2, 3, 4] | ||
relay.arange(1, 5) = [1, 2, 3, 4] | ||
relay.arange(1, 5, 1.5) = [1, 2.5, 4] | ||
""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
should we consider cases like start > stop
and step <= 0
, here? I think we probably need to at least warning or raise exceptions for step == 0
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added the sanity check in the new commit
CHECK_EQ(types.size(), 1); | ||
const ArangeAttrs* param = attrs.as<ArangeAttrs>(); | ||
IndexExpr num_elem = tvm::cast(tvm::Int(32), tvm::ceil( | ||
tvm::cast(tvm::Float(32), param->stop - param->start) / param->step)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
CHECK_NE(param->step, 0U)
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
step
is not necessary to be constant during the compilation time. So probably we should rely on IR to capture this?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see. Thanks.
std::string name = "tensor", | ||
std::string tag = kInjective) { | ||
Expr num_elem = tvm::cast(tvm::Int(32), tvm::ceil( | ||
tvm::cast(tvm::Float(32), stop - start) / step)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
not sure if we need to check if step == 0
, probably it is enough if we checked if before
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
divide by 0 should be captured by IR when step is constant.
https://github.com/dmlc/tvm/blob/master/src/lang/ir_operator.cc#L202
Hey @yzhliu @zhreshold, could you help review this PR? |
@@ -457,6 +457,40 @@ def test_infer_type_prelu(): | |||
verify_infer_type_prelu((1, 3, 2, 2), None, 1, (1, 3, 2, 2)) | |||
verify_infer_type_prelu((1, 2, 2, 3), None, 3, (1, 2, 2, 3)) | |||
|
|||
|
|||
def test_arange(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can we also have a relay frontend test from mxnet arange
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure. will add.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added in the new commit.
python/tvm/relay/op/transform.py
Outdated
@@ -249,6 +249,49 @@ def full_like(data, fill_value): | |||
return _make.full_like(data, fill_value) | |||
|
|||
|
|||
def arange(stop, start=None, step=1, dtype="float32"): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the doc does not match the arg trick though
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Made a note in docs in the new commit
src/relay/op/tensor/transform.cc
Outdated
IndexExpr num_elem = tvm::cast(tvm::Int(32), tvm::ceil( | ||
tvm::cast(tvm::Float(32), param->stop - param->start) / param->step)); | ||
if (const tvm::ir::IntImm* val = num_elem.as<tvm::ir::IntImm>()) { | ||
CHECK_GT(val->value, 0) << "Invalid arange inputs"; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
suggest to also print related params
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed in the new commit
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lgtm!
Thanks all. This is merged. |
* Add arange op * Update docs * Fix bug * add sanity check in relay and mxnet frontend mapping * lint * nits * pylint * don't allow empty output from arange * Remove empty test for arange * Fix bug and update doc
* Add arange op * Update docs * Fix bug * add sanity check in relay and mxnet frontend mapping * lint * nits * pylint * don't allow empty output from arange * Remove empty test for arange * Fix bug and update doc
* Add arange op * Update docs * Fix bug * add sanity check in relay and mxnet frontend mapping * lint * nits * pylint * don't allow empty output from arange * Remove empty test for arange * Fix bug and update doc
Currently I put start, stop, step in Relay attributes since it is required to infer the output shape. Later if Relay supports unknown dimension like
Any
, we can move them into inputs of arange op instead of attributes.This PR relies on #2615.