[Relay/TOPI][OP] Add arange op in Relay and TOPI #2621
Changes from all commits
4cfc03b
1bead2c
0b1740a
510c1cb
5a134a4
9a40b91
f78bd5f
5ac01e8
0280015
c8b63f2
@@ -880,6 +880,63 @@ and type as the input array.
.set_attr<FTVMCompute>("FTVMCompute", FullLikeCompute)
.set_attr<TOpPattern>("TOpPattern", kElemWise);

// arange operator
TVM_REGISTER_NODE_TYPE(ArangeAttrs);

bool ArangeRel(const Array<Type>& types,
               int num_inputs,
               const Attrs& attrs,
               const TypeReporter& reporter) {
  // types = {result}; arange has no tensor inputs.
  CHECK_EQ(types.size(), 1);
  const ArangeAttrs* param = attrs.as<ArangeAttrs>();
  // Output length: ceil((stop - start) / step), evaluated in float32.
  IndexExpr num_elem = tvm::cast(tvm::Int(32), tvm::ceil(
      tvm::cast(tvm::Float(32), param->stop - param->start) / param->step));
Review comment: I see. Thanks.
  // If the length folds to a constant, validate it eagerly.
  if (const tvm::ir::IntImm* val = num_elem.as<tvm::ir::IntImm>()) {
    CHECK_GT(val->value, 0)
        << "Invalid arange attributes (start, stop, step): " << param->start
        << ", " << param->stop << ", " << param->step;
  }
  reporter->Assign(types[0], TensorTypeNode::make({num_elem}, param->dtype));
  return true;
}

// Compute rule: delegate to the TOPI implementation below.
Array<Tensor> ArangeCompute(const Attrs& attrs,
                            const Array<Tensor>& inputs,
                            const Type& out_type,
                            const Target& target) {
  const ArangeAttrs* param = attrs.as<ArangeAttrs>();
  return { topi::arange(param->start, param->stop, param->step, param->dtype) };
}

// Constructor exposed to the frontend as relay.op._make.arange.
Expr MakeArange(tvm::Expr start,
                tvm::Expr stop,
                tvm::Expr step,
                DataType dtype) {
  auto attrs = make_node<ArangeAttrs>();
  attrs->start = std::move(start);
  attrs->stop = std::move(stop);
  attrs->step = std::move(step);
  attrs->dtype = std::move(dtype);
  static const Op& op = Op::Get("arange");
  return CallNode::make(op, {}, Attrs(attrs), {});
}

TVM_REGISTER_API("relay.op._make.arange")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
    runtime::detail::unpack_call<Expr, 4>(MakeArange, args, rv);
  });

RELAY_REGISTER_OP("arange")
.describe(R"code(Returns evenly spaced values within a given interval.

)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.ArangeAttrs")
.set_num_inputs(0)
.set_support_level(3)
.add_type_rel("Arange", ArangeRel)
.set_attr<FTVMCompute>("FTVMCompute", ArangeCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective);

// where operator
bool WhereRel(const Array<Type>& types,
              int num_inputs,
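From the Python side the new op is reached through the relay.arange wrapper added in this PR. Below is a minimal sketch of how ArangeRel shows up during type inference, assuming the wrapper defaults to dtype="float32" and the pre-0.6 relay.ir_pass.infer_type API; the values are illustrative only.

```python
import tvm
from tvm import relay

# arange takes no tensor inputs; start/stop/step travel as ArangeAttrs.
x = relay.arange(1, 20, 2)

# Type inference invokes ArangeRel, which sizes the output as
# ceil((20 - 1) / 2) = 10 elements of the requested dtype.
xx = relay.ir_pass.infer_type(x)
print(xx.checked_type)  # a 1-D float32 tensor; the shape expression may print unfolded
```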
@@ -457,6 +457,40 @@ def test_infer_type_prelu():
    verify_infer_type_prelu((1, 3, 2, 2), None, 1, (1, 3, 2, 2))
    verify_infer_type_prelu((1, 2, 2, 3), None, 3, (1, 2, 2, 3))


def test_arange():
Review comment: can we also have a relay frontend test from mxnet arange?
Reply: Sure, will add.
Reply: Added in the new commit.
(A sketch of such a frontend test follows this file's diff.)
    def verify_arange(start, stop, step):
        dtype = "float32"
        if start is None and step is None:
            x = relay.arange(stop)
            ref_res = np.arange(stop)
        elif start is None:
            x = relay.arange(stop, step=step)
            ref_res = np.arange(stop, step=step)
        elif step is None:
            x = relay.arange(start, stop)
            ref_res = np.arange(start, stop)
        else:
            x = relay.arange(start, stop, step)
            ref_res = np.arange(start, stop, step)

        func = relay.Function([], x)
        for target, ctx in ctx_list():
            for kind in ["graph", "debug"]:
                intrp = relay.create_executor(kind, ctx=ctx, target=target)
                op_res = intrp.evaluate(func)()
                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
    verify_arange(None, 20, None)
    verify_arange(None, 20, 2)
    verify_arange(1, 20, None)
    verify_arange(1, 20, 2)
    verify_arange(1, 20, 1.5)
    verify_arange(1, 20.5, None)
    verify_arange(1, 20, 3)
    verify_arange(20, 1, -1)
    verify_arange(20, 1, -1.5)


if __name__ == "__main__":
    test_cast()
    test_zeros_ones()
@@ -480,3 +514,4 @@ def test_infer_type_prelu():
    test_squeeze_infer_type()
    test_squeeze_bad_axes_infer_type()
    test_split_infer_type()
    test_arange()
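For the MXNet frontend test requested in the review thread above, the exact code added in the follow-up commit is not part of this diff; the snippet below is only a hedged sketch of what such a test could look like. It assumes relay.frontend.from_mxnet returns a (function, params) pair, as it did in TVM of this era, and the helper name verify_mxnet_arange is illustrative.

```python
import numpy as np
import mxnet as mx
import tvm
from tvm import relay
from tvm.relay.testing import ctx_list

def verify_mxnet_arange(start, stop, step):
    # Convert an MXNet arange symbol through the Relay frontend and
    # compare the result against numpy.
    mx_sym = mx.sym.arange(start=start, stop=stop, step=step)
    func, _ = relay.frontend.from_mxnet(mx_sym, shape={})
    for target, ctx in ctx_list():
        intrp = relay.create_executor("graph", ctx=ctx, target=target)
        op_res = intrp.evaluate(func)()
        tvm.testing.assert_allclose(op_res.asnumpy(),
                                    np.arange(start, stop, step))

verify_mxnet_arange(0, 20, 2)
```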
@@ -868,6 +868,19 @@ inline Tensor tensordot(const Tensor& A,
  return compute(output_shape, func, name, tag);
}

inline Tensor arange(const Expr start,
                     const Expr stop,
                     const Expr step,
                     Type dtype,
                     std::string name = "tensor",
                     std::string tag = kInjective) {
  // Output length: ceil((stop - start) / step), computed in float32 and
  // cast back to int32 so it can serve as a shape.
  Expr num_elem = tvm::cast(tvm::Int(32), tvm::ceil(
      tvm::cast(tvm::Float(32), stop - start) / step));
Review comment: not sure if we need to check if
Reply: divide by 0 should be captured by IR when step is constant.
  return compute({num_elem}, [&](const Array<Var>& indices) {
    // The i-th element is start + step * i, cast to the requested dtype.
    return tvm::cast(dtype, start + step * indices[0]);
  }, name, tag);
}

}  // namespace topi
#endif  // TOPI_TRANSFORM_H_
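To make the compute rule concrete, here is a plain NumPy reference of what topi::arange produces, assuming the same ceil-based length; arange_reference is an illustrative name, not part of the PR.

```python
import math
import numpy as np

def arange_reference(start, stop, step, dtype="float32"):
    # Mirrors the TOPI compute rule: length = ceil((stop - start) / step),
    # and the i-th element is start + step * i cast to the target dtype.
    num_elem = int(math.ceil(float(stop - start) / step))
    return np.array([start + step * i for i in range(num_elem)], dtype=dtype)

# Matches numpy.arange for the forward and reverse ranges used in the tests above.
assert np.allclose(arange_reference(1, 20, 1.5), np.arange(1, 20, 1.5))
assert np.allclose(arange_reference(20, 1, -1.5), np.arange(20, 1, -1.5))
```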
Review comment: should we consider cases like start > stop and step <= 0 here? I think we probably need to at least warn or raise an exception for step == 0.
Reply: Added the sanity check in the new commit.
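To make the edge cases from this thread concrete, a small plain-Python sketch of how the length formula behaves; the exact sanity check added in the follow-up commit is not shown in this diff, so the ValueError below is only illustrative.

```python
import math

def arange_length(start, stop, step):
    # Same length formula as ArangeRel / topi::arange.
    if step == 0:
        raise ValueError("arange step must be non-zero")  # the kind of check discussed above
    return int(math.ceil(float(stop - start) / step))

print(arange_length(1, 20, 2))     # 10: ordinary forward range
print(arange_length(20, 1, -1.5))  # 13: start > stop is fine with a negative step
print(arange_length(1, 20, -1))    # -19: non-positive, rejected by CHECK_GT in ArangeRel
```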