[relay][topi] Add operation relay.nn.dilate() which calls topi.nn.dilate() (#5331)

* Add operation relay.nn.dilate() which calls topi.nn.dilate().

* Fix typo

* Set op pattern to injective
notoraptor authored Apr 27, 2020
1 parent a60de36 commit 639358e
Showing 6 changed files with 134 additions and 0 deletions.
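
For context, nn.dilate inserts zeros between the elements of its input along every dimension, exactly as the NumPy reference computation in the new test at the end of this diff. A minimal sketch of the intended semantics (the shapes and stride values here are illustrative, not from the commit):

# Zero-insertion dilation in plain NumPy: for each axis i, the output
# extent is (in_dim - 1) * strides[i] + 1, input values land at every
# strides[i]-th position, and everything else is zero.
import numpy as np

data = np.arange(6, dtype="float32").reshape(2, 3)
strides = (2, 3)
out_shape = tuple((d - 1) * s + 1 for d, s in zip(data.shape, strides))
out = np.zeros(out_shape, dtype=data.dtype)
out[tuple(slice(None, None, s) for s in strides)] = data
# data.shape == (2, 3)  ->  out.shape == (3, 7)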
10 changes: 10 additions & 0 deletions include/tvm/relay/attrs/nn.h
@@ -442,6 +442,16 @@ struct Conv2DTransposeAttrs : public tvm::AttrsNode<Conv2DTransposeAttrs> {
  }
};

/*! \brief Attributes used in dilate operator */
struct DilateAttrs : public tvm::AttrsNode<DilateAttrs> {
  Array<IndexExpr> strides;

  TVM_DECLARE_ATTRS(DilateAttrs, "relay.attrs.DilateAttrs") {
    TVM_ATTR_FIELD(strides).set_default(Array<IndexExpr>({1, 1}))
        .describe("Dilation stride on each dimension, 1 means no dilation.");
  }
};

/*! \brief Attributes used in 1D transposed convolution operator */
struct Conv1DTransposeAttrs : public tvm::AttrsNode<Conv1DTransposeAttrs> {
  IndexExpr channels;
24 changes: 24 additions & 0 deletions python/tvm/relay/op/nn/_nn.py
@@ -502,6 +502,15 @@ def compute_cross_entropy(attrs, inputs, out_dtype):
reg.register_pattern("nn.cross_entropy", OpPattern.OPAQUE)


# dilate
@reg.register_compute("nn.dilate")
def compute_dilate(attrs, inputs, out_dtype):
    """Compute definition of dilate"""
    return [topi.nn.dilate(inputs[0], attrs.strides)]

reg.register_broadcast_schedule("nn.dilate")
reg.register_pattern("nn.dilate", OpPattern.INJECTIVE)


# cross_entropy_with_logits
@reg.register_compute("nn.cross_entropy_with_logits")
def compute_cross_entropy_with_logits(attrs, inputs, out_dtype):
@@ -697,6 +706,21 @@ def pad_shape_func(attrs, inputs, _):
        pad_width.append(get_const_tuple(pair))
    return [_pad_shape_func(inputs[0], convert(pad_width))]

@script
def _dilate_shape_func(data_shape, strides):
    out = output_tensor((data_shape.shape[0],), "int64")
    for i in const_range(out.shape[0]):
        out[i] = (data_shape[i] - 1) * strides[i] + 1

    return out

@reg.register_shape_func("nn.dilate", False)
def dilate_shape_func(attrs, inputs, _):
    """
    Shape function for dilate op.
    """
    return [_dilate_shape_func(inputs[0], convert(attrs.strides))]

reg.register_shape_func("nn.bias_add", False, elemwise_shape_func)
reg.register_shape_func("nn.softmax", False, elemwise_shape_func)
reg.register_shape_func("nn.relu", False, elemwise_shape_func)
19 changes: 19 additions & 0 deletions python/tvm/relay/op/nn/nn.py
@@ -1347,6 +1347,25 @@ def pad(data,
    return _make.pad(data, pad_width, pad_value, pad_mode)


def dilate(data, strides):
    """Dilate data with zeros.

    Parameters
    ----------
    data : tvm.relay.Expr
        n-D, can be any layout.

    strides : tuple of int
        Dilation stride on each dimension, 1 means no dilation.

    Returns
    -------
    Output : tvm.relay.Expr
        The computed result.
    """
    return _make.dilate(data, strides)


def mirror_pad(data,
pad_width,
mode="SYMMETRIC"):
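A minimal end-to-end sketch of the new Python API, mirroring how the tests at the end of this diff drive it (the executor call matches the test code; the shapes are illustrative):

# Build and run a small Relay program using the new nn.dilate op.
import numpy as np
import tvm
from tvm import relay

data = relay.var("data", shape=(2, 3), dtype="float32")
y = relay.nn.dilate(data, strides=(2, 3))
mod = tvm.IRModule.from_expr(relay.Function([data], y))

ex = relay.create_executor("debug", mod=mod, ctx=tvm.cpu(), target="llvm")
x = np.arange(6, dtype="float32").reshape(2, 3)
print(ex.evaluate()(x).asnumpy().shape)  # (3, 7)
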
5 changes: 5 additions & 0 deletions python/tvm/relay/op/op_attrs.py
@@ -350,6 +350,11 @@ class Conv2DTransposeAttrs(Attrs):
    """Attributes used in Transposed Conv2D operators"""


@tvm._ffi.register_object("relay.attrs.DilateAttrs")
class DilateAttrs(Attrs):
    """Attributes used in dilate operators"""


@tvm._ffi.register_object("relay.attrs.SubPixelAttrs")
class SubPixelAttrs(Attrs):
"""Attributes used in depth to space and space to depth operators"""
48 changes: 48 additions & 0 deletions src/relay/op/nn/nn.cc
@@ -1035,6 +1035,54 @@ Do log on the data - do not accept logits.
.add_type_rel("CrossEntropy", CrossEntropyRel);


// relay.nn.dilate
TVM_REGISTER_NODE_TYPE(DilateAttrs);

bool DilateRel(const Array<Type>& types,
               int num_inputs,
               const Attrs& attrs,
               const TypeReporter& reporter) {
  CHECK_EQ(types.size(), 2);
  const auto* x = types[0].as<TensorTypeNode>();
  const DilateAttrs* param = attrs.as<DilateAttrs>();
  if (x == nullptr) return false;
  CHECK_EQ(x->shape.size(), param->strides.size());

  std::vector<IndexExpr> oshape;
  for (size_t i = 0; i < param->strides.size(); ++i) {
    if (!x->shape[i].as<tir::AnyNode>()) {
      oshape.push_back((x->shape[i] - 1) * param->strides[i] + 1);
    } else {
      oshape.push_back(x->shape[i]);
    }
  }

  reporter->Assign(types[1], TensorType(Array<IndexExpr>(oshape), x->dtype));
  return true;
}

// Positional relay function to create dilate operator used by frontend FFI.
Expr MakeDilate(Expr data, Array<IndexExpr> strides) {
  auto attrs = make_object<DilateAttrs>();
  attrs->strides = std::move(strides);
  static const Op& op = Op::Get("nn.dilate");
  return Call(op, {data}, Attrs(attrs), {});
}


TVM_REGISTER_GLOBAL("relay.op.nn._make.dilate")
.set_body_typed(MakeDilate);


RELAY_REGISTER_OP("nn.dilate")
.describe(R"code(
Dilate data with zeros.
)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
.add_argument("x", "nD Tensor", "Data to dilate.")
.set_support_level(10)
.add_type_rel("Dilate", DilateRel);

// Positional relay function to create cross_entropy_with_logits operator used by frontend FFI.
Expr MakeCrossEntropyWithLogits(Expr predictions, Expr targets) {
  static const Op& op = Op::Get("nn.cross_entropy_with_logits");
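DilateRel above propagates Any dimensions unchanged and otherwise computes (d - 1) * stride + 1 per axis. A sketch of observing the inferred type from Python, assuming the standard InferType pass (shapes illustrative):

# Inspect the output type produced by the "Dilate" type relation.
import tvm
from tvm import relay

x = relay.var("x", shape=(1, 2, 3), dtype="float32")
f = relay.Function([x], relay.nn.dilate(x, strides=(3, 7, 5)))
mod = relay.transform.InferType()(tvm.IRModule.from_expr(f))
print(mod["main"].ret_type)  # Tensor[(1, 8, 11), float32]
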
28 changes: 28 additions & 0 deletions tests/python/relay/test_any.py
@@ -508,6 +508,34 @@ def test_any_pad():
    verify_any_pad(any_dims(3), ((0, 0), (1, 1), (2, 2)), (1, 2, 3))
    verify_any_pad(any_dims(4), ((1, 0), (1, 3), (0, 2), (9, 0)), (13, 11, 3, 1))

def verify_any_dilate(data_shape, strides, static_data_shape):
    assert len(data_shape) == len(strides)
    mod = tvm.IRModule()
    dtype = "float32"
    data = relay.var('data', shape=data_shape, dtype=dtype)
    y = relay.nn.dilate(data, strides)
    mod["main"] = relay.Function([data], y)
    data_np = np.random.uniform(size=static_data_shape).astype(dtype)
    ref_shape = tuple((static_data_shape[i] - 1) * strides[i] + 1
                      for i in range(len(static_data_shape)))
    ref_out = np.zeros(shape=ref_shape, dtype=dtype)
    ref_out[tuple(slice(None, None, strides[i]) for i in range(len(data_shape)))] = data_np

    for kind in ["debug", "vm"]:
        ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
        result = ex.evaluate()(data_np)
        tvm.testing.assert_allclose(result.asnumpy(), ref_out)

def test_any_dilate():
    verify_any_dilate(any_dims(1), (1,), (1,))
    verify_any_dilate(any_dims(1), (1,), (5,))
    verify_any_dilate(any_dims(1), (5,), (5,))
    verify_any_dilate(any_dims(3), (1, 1, 1), (1, 2, 3))
    verify_any_dilate(any_dims(3), (1, 1, 2), (1, 2, 3))
    verify_any_dilate(any_dims(3), (1, 1, 5), (1, 2, 3))
    verify_any_dilate(any_dims(3), (3, 7, 5), (1, 2, 3))
    verify_any_dilate(any_dims(4), (3, 7, 1, 5), (1, 2, 3, 4))

def verify_any_softmax(data_shape, axis, static_data_shape, ref_out_shape):
    mod = tvm.IRModule()
    dtype = "float32"
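Like its neighbors in test_any.py, test_any_dilate is a plain function with no fixtures, so it can be invoked directly; a sketch of the usual module guard (TVM test files of this era commonly end with one):

# Hypothetical direct invocation of the new test.
if __name__ == "__main__":
    test_any_dilate()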
