Skip to content

Commit

Permalink
Dynamic Tile Op (apache#5983)
Browse files Browse the repository at this point in the history
* first working dynamic tile passes first test

* add dyn tile to dynamic_to_static

* fix cpplintt

* respond to review comments. Thanks @siju-samuel

* make dynamic tile compatible with numpy API
  • Loading branch information
Matthew Brookhart authored and trevor-m committed Jul 14, 2020
1 parent 38804ef commit 499e9c8
Show file tree
Hide file tree
Showing 7 changed files with 195 additions and 3 deletions.
40 changes: 40 additions & 0 deletions python/tvm/relay/op/dyn/_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,13 @@
"""Backend compiler related feature registration"""
# pylint: disable=invalid-name,unused-argument, len-as-condition, too-many-nested-blocks, too-many-local-variables, too-many-arguments
from __future__ import absolute_import

from tvm.runtime import convert
from tvm.te.hybrid import script
from .. import op as _reg

_reg.register_injective_schedule("dyn.reshape")
_reg.register_broadcast_schedule("dyn.tile")

@script
def _reshape_shape_func_input_data(data, newshape, ndim):
Expand Down Expand Up @@ -81,3 +84,40 @@ def _reshape_shape_func_input_data(data, newshape, ndim):
@_reg.register_shape_func("dyn.reshape", True)
def dynamic_reshape_shape_func(attrs, inputs, out_ndims):
return [_reshape_shape_func_input_data(*inputs, out_ndims[0])]


@script
def _tile_shape_func(data, reps, ndim, tndim, rndim):
out = output_tensor((tndim,), "int64")

if ndim == rndim:
for i in const_range(tndim):
out[i] = int64(data.shape[i] * reps[i])
elif ndim > rndim:
ngap = ndim - rndim
for i in const_range(ndim):
if i < ngap:
out[i] = int64(data.shape[i])
else:
out[i] = int64(data.shape[i] * reps[i - ngap])
else:
rgap = rndim - ndim
for i in const_range(rndim):
if i < rgap:
out[i] = int64(reps[i])
else:
out[i] = int64(reps[i] * data.shape[i - rgap])
return out


@_reg.register_shape_func("dyn.tile", True)
def tile_shape_func(attrs, inputs, _):
"""
Shape function for dyn.tile op.
"""
reps = inputs[1]
ndim = len(inputs[0].shape)
rndim = inputs[1].shape[0].value
tndim = ndim if ndim > rndim else rndim
return [_tile_shape_func(inputs[0], reps, convert(ndim),
convert(tndim), convert(rndim))]
5 changes: 3 additions & 2 deletions python/tvm/relay/op/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,7 +496,7 @@ def tile(data, reps):
data : relay.Expr
The input data to the operator.
reps : tuple of int
reps : tuple of int or relay.Expr
The number of times repeating the tensor data.
Returns
Expand Down Expand Up @@ -524,7 +524,8 @@ def tile(data, reps):
data is promoted to be d-dimensional by prepending new axes.
If data.ndim >= d, reps is promoted to a.ndim by pre-pending 1's to it.
"""

if isinstance(reps, Expr):
return _dyn_make.tile(data, reps)
return _make.tile(data, reps)


Expand Down
67 changes: 67 additions & 0 deletions src/relay/op/dyn/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
#include <tvm/relay/op_attr_types.h>
#include <tvm/runtime/registry.h>

#include <vector>

namespace tvm {
namespace relay {
namespace dyn {
Expand Down Expand Up @@ -128,6 +130,71 @@ RELAY_REGISTER_OP("dyn.reshape")
.set_attr<FTVMCompute>("FTVMCompute", ReshapeCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective);

// tile operator
// TVM_REGISTER_NODE_TYPE(TileAttrs);

bool TileRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
// `types` contains: [data, reps, result]
CHECK_EQ(types.size(), 3);
const auto* data = types[0].as<TensorTypeNode>();
const auto* reps = types[1].as<TensorTypeNode>();
if (data == nullptr) {
CHECK(types[0].as<IncompleteTypeNode>())
<< "tile: expect input type to be TensorType but get " << types[0];
return false;
}
if (reps == nullptr) {
CHECK(types[1].as<IncompleteTypeNode>())
<< "tile: expect input type to be TensorType but get " << types[1];
return false;
}
const IntImmNode* reps_shape = reps->shape[0].as<IntImmNode>();
CHECK(reps_shape) << "Parameter reps must have static shape";
const size_t ndim = data->shape.size();
const size_t rndim = reps_shape->value;
size_t tndim = (ndim > rndim) ? ndim : rndim;
std::vector<IndexExpr> oshape;
oshape.reserve(tndim);
for (size_t i = 0; i < tndim; ++i) {
oshape.emplace_back(Any());
}
reporter->Assign(types[2], TensorType(oshape, data->dtype));
return true;
}

Array<te::Tensor> TileCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
const Type& out_type) {
CHECK_EQ(inputs.size(), 2);
const auto* out_ttype = out_type.as<TensorTypeNode>();
size_t rndim = inputs[1]->shape[0].as<IntImmNode>()->value;
return {topi::dyn_tile(inputs[0], out_ttype->shape, rndim)};
}

Expr MakeTile(Expr data, Expr reps) {
auto attrs = make_object<TileAttrs>();
static const Op& op = Op::Get("dyn.tile");
return Call(op, {data, reps}, Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relay.op.dyn._make.tile").set_body_typed(MakeTile);

RELAY_REGISTER_OP("dyn.tile")
.describe(R"code(Repeat the whole array multiple times.
- **data**: The input data to the operator.
- **reps**: The number of times to repeat the operator.
)code" TVM_ADD_FILELINE)
.set_num_inputs(2)
.set_attrs_type<TileAttrs>()
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("reps", "Tensor", "The number of times to repeat the input on each axis.")
.set_support_level(3)
.add_type_rel("DynamicTile", TileRel)
.set_attr<FTVMCompute>("FTVMCompute", TileCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective);

} // namespace dyn
} // namespace relay
} // namespace tvm
12 changes: 11 additions & 1 deletion src/relay/transforms/dynamic_to_static.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ namespace relay {

class DynamicToStaticMutator : public MixedModeMutator {
public:
DynamicToStaticMutator() : dyn_reshape_op_(Op::Get("dyn.reshape")) {}
DynamicToStaticMutator()
: dyn_reshape_op_(Op::Get("dyn.reshape")), dyn_tile_op_(Op::Get("dyn.tile")) {}

private:
Expr Rewrite_(const CallNode* pre, const Expr& post) override {
Expand All @@ -46,6 +47,14 @@ class DynamicToStaticMutator : public MixedModeMutator {
static const Op& reshape = Op::Get("reshape");
return Call(reshape, {call_node->args[0]}, Attrs(attrs), {});
}
} else if (call_node->op == dyn_tile_op_) {
if (const ConstantNode* reps = call_node->args[1].as<ConstantNode>()) {
auto attrs = make_object<TileAttrs>();
CHECK_EQ(reps->data->ndim, 1);
attrs->reps = ToVector(reps->data);
static const Op& op = Op::Get("tile");
return Call(op, {call_node->args[0]}, Attrs(attrs), {});
}
}
return post;
}
Expand All @@ -58,6 +67,7 @@ class DynamicToStaticMutator : public MixedModeMutator {
}

const Op& dyn_reshape_op_;
const Op& dyn_tile_op_;
};

Expr DynamicToStatic(Function f, IRModule m) {
Expand Down
15 changes: 15 additions & 0 deletions tests/python/relay/dyn/test_dynamic_op_level3.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,21 @@ def verify_reshape(shape, newshape, oshape):
verify_reshape((2, 3, 4), (8, 3), (8, 3))
verify_reshape((4, 7), (2, 7, 2), (2, 7, 2))

def test_dyn_tile():
def verify_tile(dshape, reps):
x = relay.var("x", relay.TensorType(dshape, "float32"))
r = relay.var("reps", relay.TensorType((len(reps), ), "float32"))
z = relay.tile(x, r)

func = relay.Function([x, r], z)
x_data = np.random.uniform(low=-1, high=1, size=dshape).astype("float32")
ref_res = np.tile(x_data, reps=reps)
verify_func(func, [x_data, np.array(reps).astype("float32")], ref_res)
verify_tile((2, 3, 4), (3, 2, 1))
verify_tile((2, 3, 4), (1, 2))
verify_tile((2, 3), (3, 2, 1))

if __name__ == "__main__":
test_dyn_reshape()
test_dyn_shape_reshape()
test_dyn_tile()
22 changes: 22 additions & 0 deletions tests/python/relay/test_pass_dynamic_to_static.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,30 @@ def verify_reshape(shape, newshape):
verify_reshape((2, 3, 4), (8, 3))
verify_reshape((4, 7), (2, 7, 2))

def test_dynamic_to_static_tile():
def verify_tile(shape, reps, oshape):
x = relay.var("x", relay.TensorType(shape, "float32"))
y = relay.var("y", relay.TensorType(reps, "float32"))
z = relay.tile(x, relay.shape_of(y))
func = run_infer_type(relay.Function([x, y], z))
func2 = run_opt_pass(run_opt_pass(func, transform.DynamicToStatic()), transform.InferType())

zz = func2.body
assert isinstance(zz, relay.Call)
assert zz.op == relay.op.get("tile")
assert zz.checked_type == relay.ty.TensorType(oshape, "float32")

x_data = np.random.uniform(low=-1, high=1, size=shape).astype("float32")
y_data = np.random.uniform(low=-1, high=1, size=reps).astype("float32")
ref_res = np.tile(x_data, reps)
verify_func(func2, [x_data, y_data], ref_res)

verify_tile((2, 3, 4), (2, 1, 5), (4, 3, 20))
verify_tile((4, 7), (4, 2), (16, 14))

if __name__=="__main__":
test_dynamic_to_static_reshape()
test_dynamic_to_static_double_reshape()
test_dynamic_to_static_quad_reshape()
test_dynamic_to_static_tile()

37 changes: 37 additions & 0 deletions topi/include/topi/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -1016,6 +1016,43 @@ inline Tensor tile(const Tensor& x, Array<Integer> reps, std::string name = "T_t
}
}

/*!
* \brief Creates an operation to tile elements of an array
*
* \param x The input tensor
* \param new_shape The shape of the output after tiling
* \param rdim The rank of the reps, provided by caller
* \param name The name of the operation
* \param tag The tag to mark the operation
*
* \return A Tensor whose op member is the tile operation
*/
inline Tensor dyn_tile(const Tensor& x, Array<PrimExpr> new_shape, size_t rdim,
std::string name = "T_tile", std::string tag = kBroadcast) {
size_t ndim = x->shape.size();
if (is_empty_shape(new_shape)) {
return compute(
new_shape, [&](const Array<Var>& indices) { return tvm::cast(x->dtype, 0); }, name, tag);
} else {
return compute(
new_shape,
[&](const Array<Var>& indices) {
Array<PrimExpr> idx;
if (ndim >= rdim) {
for (size_t i = 0; i < ndim; ++i) {
idx.push_back(indexmod(indices[i], x->shape[i]));
}
} else {
for (size_t i = 0; i < ndim; ++i) {
idx.push_back(indexmod(indices[rdim - ndim + i], x->shape[i]));
}
}
return x(idx);
},
name, tag);
}
}

/*!
* \brief Gather values along given axis from given indices.
*
Expand Down

0 comments on commit 499e9c8

Please sign in to comment.