Dynamic Tile Op #5983

Merged: 5 commits, Jul 7, 2020
40 changes: 40 additions & 0 deletions python/tvm/relay/op/dyn/_transform.py
@@ -17,10 +17,13 @@
"""Backend compiler related feature registration"""
# pylint: disable=invalid-name,unused-argument, len-as-condition, too-many-nested-blocks, too-many-local-variables, too-many-arguments
from __future__ import absolute_import

from tvm.runtime import convert
from tvm.te.hybrid import script
from .. import op as _reg

_reg.register_injective_schedule("dyn.reshape")
_reg.register_broadcast_schedule("dyn.tile")

@script
def _reshape_shape_func_input_data(data, newshape, ndim):
@@ -81,3 +84,40 @@ def _reshape_shape_func_input_data(data, newshape, ndim):
@_reg.register_shape_func("dyn.reshape", True)
def dynamic_reshape_shape_func(attrs, inputs, out_ndims):
    return [_reshape_shape_func_input_data(*inputs, out_ndims[0])]


@script
def _tile_shape_func(data, reps, ndim, tndim, rndim):
    out = output_tensor((tndim,), "int64")

    if ndim == rndim:
        for i in const_range(tndim):
            out[i] = int64(data.shape[i] * reps[i])
    elif ndim > rndim:
        ngap = ndim - rndim
        for i in const_range(ndim):
            if i < ngap:
                out[i] = int64(data.shape[i])
            else:
                out[i] = int64(data.shape[i] * reps[i - ngap])
    else:
        rgap = rndim - ndim
        for i in const_range(rndim):
            if i < rgap:
                out[i] = int64(reps[i])
            else:
                out[i] = int64(reps[i] * data.shape[i - rgap])
    return out


@_reg.register_shape_func("dyn.tile", True)
def tile_shape_func(attrs, inputs, _):
"""
Shape function for dyn.tile op.
"""
reps = inputs[1]
ndim = len(inputs[0].shape)
rndim = inputs[1].shape[0].value
tndim = ndim if ndim > rndim else rndim
return [_tile_shape_func(inputs[0], reps, convert(ndim),
convert(tndim), convert(rndim))]
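
The hybrid script aligns the two ranks from the right: the shorter of `data.shape` and `reps` is implicitly left-padded (with the existing extents, or with 1's) before the per-axis multiply, matching `np.tile` semantics. A plain-Python sketch of the same rule (illustration only, not part of this diff):

def tile_out_shape(data_shape, reps):
    # Mirror of _tile_shape_func: left-pad the shorter sequence,
    # then multiply extent by extent.
    ndim, rndim = len(data_shape), len(reps)
    if ndim >= rndim:
        reps = (1,) * (ndim - rndim) + tuple(reps)              # ndim > rndim branch
    else:
        data_shape = (1,) * (rndim - ndim) + tuple(data_shape)  # rndim > ndim branch
    return tuple(d * r for d, r in zip(data_shape, reps))

assert tile_out_shape((2, 3, 4), (3, 2, 1)) == (6, 6, 4)
assert tile_out_shape((2, 3, 4), (1, 2)) == (2, 3, 8)
assert tile_out_shape((2, 3), (3, 2, 1)) == (3, 4, 3)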
5 changes: 3 additions & 2 deletions python/tvm/relay/op/transform.py
@@ -496,7 +496,7 @@ def tile(data, reps):
    data : relay.Expr
        The input data to the operator.

    reps : tuple of int or relay.Expr
        The number of times to repeat the tensor data.

    Returns
@@ -524,7 +524,8 @@ def tile(data, reps):
    data is promoted to be d-dimensional by prepending new axes.
    If data.ndim >= d, reps is promoted to data.ndim by prepending 1's to it.
    """

    if isinstance(reps, Expr):
        return _dyn_make.tile(data, reps)
    return _make.tile(data, reps)
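
The dispatch above keeps a single user-facing entry point: a Python tuple or list of ints still takes the static path, while any `relay.Expr` is routed to the new dynamic op. A short usage sketch (variable names are illustrative, not part of this diff):

from tvm import relay

x = relay.var("x", shape=(2, 3), dtype="float32")

# Static reps: a Python tuple lowers through _make.tile as before.
z_static = relay.tile(x, reps=(2, 2))

# Dynamic reps: a 1-D tensor expression lowers through _dyn_make.tile.
r = relay.var("r", shape=(2,), dtype="int64")
z_dyn = relay.tile(x, reps=r)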


67 changes: 67 additions & 0 deletions src/relay/op/dyn/tensor/transform.cc
@@ -29,6 +29,8 @@
#include <tvm/relay/op_attr_types.h>
#include <tvm/runtime/registry.h>

#include <vector>

namespace tvm {
namespace relay {
namespace dyn {
@@ -128,6 +130,71 @@ RELAY_REGISTER_OP("dyn.reshape")
.set_attr<FTVMCompute>("FTVMCompute", ReshapeCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective);

// tile operator
// TVM_REGISTER_NODE_TYPE(TileAttrs);  // TileAttrs is already registered by the static tile op

bool TileRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
             const TypeReporter& reporter) {
  // `types` contains: [data, reps, result]
  CHECK_EQ(types.size(), 3);
  const auto* data = types[0].as<TensorTypeNode>();
  const auto* reps = types[1].as<TensorTypeNode>();
  if (data == nullptr) {
    CHECK(types[0].as<IncompleteTypeNode>())
        << "tile: expect input type to be TensorType but get " << types[0];
    return false;
  }
  if (reps == nullptr) {
    CHECK(types[1].as<IncompleteTypeNode>())
        << "tile: expect input type to be TensorType but get " << types[1];
    return false;
  }
  const IntImmNode* reps_shape = reps->shape[0].as<IntImmNode>();
  CHECK(reps_shape) << "Parameter reps must have static shape";
  const size_t ndim = data->shape.size();
  const size_t rndim = reps_shape->value;
  size_t tndim = (ndim > rndim) ? ndim : rndim;
  std::vector<IndexExpr> oshape;
  oshape.reserve(tndim);
  for (size_t i = 0; i < tndim; ++i) {
    oshape.emplace_back(Any());
  }
  reporter->Assign(types[2], TensorType(oshape, data->dtype));
  return true;
}
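
Because the values in `reps` are unknown at compile time, `TileRel` can only pin down the output rank, `max(ndim, rndim)`; every output extent is assigned `Any()`. A sketch of what type inference then reports (illustrative; `run_infer_type` is the same helper the tests below use):

from tvm import relay
from tvm.relay.testing import run_infer_type

x = relay.var("x", shape=(2, 3), dtype="float32")
r = relay.var("r", shape=(2,), dtype="int64")
z = run_infer_type(relay.tile(x, r))
# The rank is known but the extents are not:
# z.checked_type is Tensor[(?, ?), float32]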

Array<te::Tensor> TileCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
const Type& out_type) {
CHECK_EQ(inputs.size(), 2);
const auto* out_ttype = out_type.as<TensorTypeNode>();
size_t rndim = inputs[1]->shape[0].as<IntImmNode>()->value;
return {topi::dyn_tile(inputs[0], out_ttype->shape, rndim)};
}

Expr MakeTile(Expr data, Expr reps) {
  auto attrs = make_object<TileAttrs>();
  static const Op& op = Op::Get("dyn.tile");
  return Call(op, {data, reps}, Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relay.op.dyn._make.tile").set_body_typed(MakeTile);

RELAY_REGISTER_OP("dyn.tile")
.describe(R"code(Repeat the whole array multiple times.

- **data**: The input data to the operator.
- **reps**: The number of times to repeat the operator.

)code" TVM_ADD_FILELINE)
.set_num_inputs(2)
.set_attrs_type<TileAttrs>()
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("reps", "Tensor", "The number of times to repeat the input on each axis.")
.set_support_level(3)
.add_type_rel("DynamicTile", TileRel)
.set_attr<FTVMCompute>("FTVMCompute", TileCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective);

} // namespace dyn
} // namespace relay
} // namespace tvm
12 changes: 11 additions & 1 deletion src/relay/transforms/dynamic_to_static.cc
@@ -32,7 +32,8 @@ namespace relay {

class DynamicToStaticMutator : public MixedModeMutator {
 public:
  DynamicToStaticMutator()
      : dyn_reshape_op_(Op::Get("dyn.reshape")), dyn_tile_op_(Op::Get("dyn.tile")) {}

 private:
  Expr Rewrite_(const CallNode* pre, const Expr& post) override {
@@ -46,6 +47,14 @@ class DynamicToStaticMutator : public MixedModeMutator {
static const Op& reshape = Op::Get("reshape");
return Call(reshape, {call_node->args[0]}, Attrs(attrs), {});
}
} else if (call_node->op == dyn_tile_op_) {
zhiics marked this conversation as resolved.
Show resolved Hide resolved
if (const ConstantNode* reps = call_node->args[1].as<ConstantNode>()) {
auto attrs = make_object<TileAttrs>();
CHECK_EQ(reps->data->ndim, 1);
attrs->reps = ToVector(reps->data);
static const Op& op = Op::Get("tile");
return Call(op, {call_node->args[0]}, Attrs(attrs), {});
}
}
return post;
}
@@ -58,6 +67,7 @@ class DynamicToStaticMutator : public MixedModeMutator {
  }

  const Op& dyn_reshape_op_;
  const Op& dyn_tile_op_;
};

Expr DynamicToStatic(Function f, IRModule m) {
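
When `reps` arrives as a `ConstantNode`, the mutator reads its values out with `ToVector` and re-emits the static `tile` op, so downstream passes see fully static shapes. A Python-level sketch of that round trip (illustration only, not part of this diff):

import numpy as np
import tvm
from tvm import relay

x = relay.var("x", shape=(2, 3), dtype="float32")
# relay.const is an Expr, so relay.tile builds dyn.tile here; the
# constant reps then lets DynamicToStatic fold it back to static tile.
z = relay.tile(x, relay.const(np.array([2, 2]), "int64"))
mod = tvm.IRModule.from_expr(relay.Function([x], z))
mod = relay.transform.InferType()(mod)
mod = relay.transform.DynamicToStatic()(mod)
# mod["main"].body now calls the static "tile" op with reps == [2, 2].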
15 changes: 15 additions & 0 deletions tests/python/relay/dyn/test_dynamic_op_level3.py
@@ -70,6 +70,21 @@ def verify_reshape(shape, newshape, oshape):
    verify_reshape((2, 3, 4), (8, 3), (8, 3))
    verify_reshape((4, 7), (2, 7, 2), (2, 7, 2))

def test_dyn_tile():
    def verify_tile(dshape, reps):
        x = relay.var("x", relay.TensorType(dshape, "float32"))
        r = relay.var("reps", relay.TensorType((len(reps), ), "float32"))
        z = relay.tile(x, r)

        func = relay.Function([x, r], z)
        x_data = np.random.uniform(low=-1, high=1, size=dshape).astype("float32")
        ref_res = np.tile(x_data, reps=reps)
        verify_func(func, [x_data, np.array(reps).astype("float32")], ref_res)

    verify_tile((2, 3, 4), (3, 2, 1))
    verify_tile((2, 3, 4), (1, 2))
    verify_tile((2, 3), (3, 2, 1))

if __name__ == "__main__":
    test_dyn_reshape()
    test_dyn_shape_reshape()
    test_dyn_tile()
22 changes: 22 additions & 0 deletions tests/python/relay/test_pass_dynamic_to_static.py
@@ -108,8 +108,30 @@ def verify_reshape(shape, newshape):
    verify_reshape((2, 3, 4), (8, 3))
    verify_reshape((4, 7), (2, 7, 2))

def test_dynamic_to_static_tile():
    def verify_tile(shape, reps, oshape):
        x = relay.var("x", relay.TensorType(shape, "float32"))
        y = relay.var("y", relay.TensorType(reps, "float32"))
        z = relay.tile(x, relay.shape_of(y))
        func = run_infer_type(relay.Function([x, y], z))
        func2 = run_opt_pass(run_opt_pass(func, transform.DynamicToStatic()), transform.InferType())

        zz = func2.body
        assert isinstance(zz, relay.Call)
        assert zz.op == relay.op.get("tile")
        assert zz.checked_type == relay.ty.TensorType(oshape, "float32")

        x_data = np.random.uniform(low=-1, high=1, size=shape).astype("float32")
        y_data = np.random.uniform(low=-1, high=1, size=reps).astype("float32")
        ref_res = np.tile(x_data, reps)
        verify_func(func2, [x_data, y_data], ref_res)

    verify_tile((2, 3, 4), (2, 1, 5), (4, 3, 20))
    verify_tile((4, 7), (4, 2), (16, 14))

if __name__=="__main__":
test_dynamic_to_static_reshape()
test_dynamic_to_static_double_reshape()
test_dynamic_to_static_quad_reshape()
test_dynamic_to_static_tile()

37 changes: 37 additions & 0 deletions topi/include/topi/transform.h
@@ -1016,6 +1016,43 @@ inline Tensor tile(const Tensor& x, Array<Integer> reps, std::string name = "T_t
}
}

/*!
* \brief Creates an operation to tile elements of an array
*
* \param x The input tensor
* \param new_shape The shape of the output after tiling
* \param rdim The rank of the reps, provided by caller
* \param name The name of the operation
* \param tag The tag to mark the operation
*
* \return A Tensor whose op member is the tile operation
*/
inline Tensor dyn_tile(const Tensor& x, Array<PrimExpr> new_shape, size_t rdim,
                       std::string name = "T_tile", std::string tag = kBroadcast) {
  size_t ndim = x->shape.size();
  if (is_empty_shape(new_shape)) {
    return compute(
        new_shape, [&](const Array<Var>& indices) { return tvm::cast(x->dtype, 0); }, name, tag);
  } else {
    return compute(
        new_shape,
        [&](const Array<Var>& indices) {
          Array<PrimExpr> idx;
          if (ndim >= rdim) {
            for (size_t i = 0; i < ndim; ++i) {
              idx.push_back(indexmod(indices[i], x->shape[i]));
            }
          } else {
            for (size_t i = 0; i < ndim; ++i) {
              idx.push_back(indexmod(indices[rdim - ndim + i], x->shape[i]));
            }
          }
          return x(idx);
        },
        name, tag);
  }
}
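
The compute reads each output element from `x` at the output index taken modulo the input extent, so the input repeats along every axis; when the output rank exceeds the input rank, only the trailing `ndim` output indices address `x`. A NumPy sketch of the same indexing (illustration only, not part of this diff):

import numpy as np

def dyn_tile_ref(x, new_shape):
    # Reference for dyn_tile: the trailing x.ndim axes of each output
    # index are wrapped by the input extents (indexmod in the C++ above).
    out = np.empty(new_shape, dtype=x.dtype)
    ndim, tndim = x.ndim, len(new_shape)
    for idx in np.ndindex(*new_shape):
        src = tuple(idx[tndim - ndim + d] % x.shape[d] for d in range(ndim))
        out[idx] = x[src]
    return out

x = np.arange(6, dtype="float32").reshape(2, 3)
assert (dyn_tile_ref(x, (4, 6)) == np.tile(x, (2, 2))).all()
assert (dyn_tile_ref(x, (3, 2, 3)) == np.tile(x, (3, 1, 1))).all()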

/*!
* \brief Gather values along given axis from given indices.
*