-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[relay][topi] Add operation relay.nn.dilate() which calls topi.nn.dilate() #5331
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -458,6 +458,15 @@ def compute_cross_entropy(attrs, inputs, out_dtype): | |
reg.register_pattern("nn.cross_entropy", OpPattern.OPAQUE) | ||
|
||
|
||
# dilate
@reg.register_compute("nn.dilate")
def compute_dilate(attrs, inputs, out_dtype):
    """Compute definition of dilate: insert (stride - 1) zeros between
    consecutive elements along each axis, per attrs.strides."""
    return [topi.nn.dilate(inputs[0], attrs.strides)]


reg.register_broadcast_schedule("nn.dilate")
# Each output element depends on at most one input element (the rest are
# constant zeros), so the op is injective rather than opaque — per review.
reg.register_pattern("nn.dilate", OpPattern.INJECTIVE)
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I wonder if the op pattern is better to be injective. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @kazum Done! |
||
|
||
|
||
# cross_entropy_with_logits | ||
@reg.register_compute("nn.cross_entropy_with_logits") | ||
def compute_cross_entropy_with_logits(attrs, inputs, out_dtype): | ||
|
@@ -653,6 +662,21 @@ def pad_shape_func(attrs, inputs, _): | |
pad_width.append(get_const_tuple(pair)) | ||
return [_pad_shape_func(inputs[0], convert(pad_width))] | ||
|
||
@script
def _dilate_shape_func(data_shape, strides):
    # Hybrid-script shape function for dilate: for each axis i the output
    # extent is (d_i - 1) * stride_i + 1, i.e. (stride_i - 1) zeros are
    # inserted between every pair of adjacent input elements.
    out = output_tensor((data_shape.shape[0],), "int64")
    for i in const_range(out.shape[0]):
        out[i] = (data_shape[i] - 1) * strides[i] + 1

    return out
|
||
@reg.register_shape_func("nn.dilate", False)
def dilate_shape_func(attrs, inputs, _):
    """
    Shape function for dilate op.

    inputs[0] is the input's shape tensor; attrs.strides holds the per-axis
    dilation factors (1 means no dilation on that axis).
    """
    return [_dilate_shape_func(inputs[0], convert(attrs.strides))]
|
||
reg.register_shape_func("nn.bias_add", False, elemwise_shape_func) | ||
reg.register_shape_func("nn.softmax", False, elemwise_shape_func) | ||
reg.register_shape_func("nn.relu", False, elemwise_shape_func) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1347,6 +1347,25 @@ def pad(data, | |
return _make.pad(data, pad_width, pad_value, pad_mode) | ||
|
||
|
||
def dilate(data, strides):
    """Dilate data with zeros.

    Parameters
    ----------
    data : tvm.relay.Expr
        n-D, can be any layout.

    strides : tuple of int
        Dilation stride on each dimension, 1 means no dilation.

    Returns
    -------
    Output : tvm.relay.Expr
        The computed result
    """
    return _make.dilate(data, strides)
|
||
|
||
def mirror_pad(data, | ||
pad_width, | ||
mode="SYMMETRIC"): | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -961,6 +961,56 @@ Do log on the data - do not accept logits. | |
.add_type_rel("CrossEntropy", CrossEntropyRel); | ||
|
||
|
||
///// | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. remove ////, down also |
||
// relay.nn.dilate | ||
TVM_REGISTER_NODE_TYPE(DilateAttrs); | ||
|
||
bool DilateRel(const Array<Type>& types, | ||
int num_inputs, | ||
const Attrs& attrs, | ||
const TypeReporter& reporter) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Align the arguments |
||
CHECK_EQ(types.size(), 2); | ||
const auto* x = types[0].as<TensorTypeNode>(); | ||
const DilateAttrs* param = attrs.as<DilateAttrs>(); | ||
if (x == nullptr) return false; | ||
CHECK_EQ(x->shape.size(), param->strides.size()); | ||
|
||
std::vector<IndexExpr> oshape; | ||
for (size_t i = 0; i < param->strides.size(); ++i) { | ||
if (!x->shape[i].as<tir::AnyNode>()) { | ||
oshape.push_back((x->shape[i] - 1) * param->strides[i] + 1); | ||
} else { | ||
oshape.push_back(x->shape[i]); | ||
} | ||
} | ||
|
||
reporter->Assign(types[1], TensorType(Array<IndexExpr>(oshape), x->dtype)); | ||
return true; | ||
} | ||
|
||
// Positional relay function to create dilate operator used by frontend FFI.
Expr MakeDilate(Expr data, Array<IndexExpr> strides) {
  auto attrs = make_object<DilateAttrs>();
  attrs->strides = std::move(strides);
  static const Op& op = Op::Get("nn.dilate");
  return Call(op, {data}, Attrs(attrs), {});
}

// Expose MakeDilate to Python as relay.op.nn._make.dilate.
TVM_REGISTER_GLOBAL("relay.op.nn._make.dilate")
.set_body_typed(MakeDilate);
|
||
|
||
RELAY_REGISTER_OP("nn.dilate")
.describe(R"code(
Dilate data with zeros.
)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
// The op accepts any rank (DilateRel only requires rank == len(strides)),
// so the argument is described as a plain Tensor, not "1D Tensor".
.add_argument("x", "Tensor", "Data to dilate.")
.set_support_level(10)
.add_type_rel("Dilate", DilateRel);
|
||
// Positional relay function to create cross_entropy_with_logits operator used by frontend FFI. | ||
Expr MakeCrossEntropyWithLogits(Expr predictions, Expr targets) { | ||
static const Op& op = Op::Get("nn.cross_entropy_with_logits"); | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
remove ////, down also