Skip to content

Commit

Permalink
comment addressed
Browse files Browse the repository at this point in the history
  • Loading branch information
Laurawly committed Mar 10, 2019
1 parent 3c31cca commit 540195a
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 15 deletions.
4 changes: 2 additions & 2 deletions src/relay/op/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1131,12 +1131,12 @@ bool TileRel(const Array<Type>& types,
return false;
}
const auto* param = attrs.as<TileAttrs>();
const size_t ndim = static_cast<size_t>(data->shape.size());
const size_t ndim = data->shape.size();
const Array<Integer>& reps = param->reps;
// check dimension match
CHECK(!reps.defined())
<< "repetition array is not defined. data.ndim = " << ndim;
const size_t rndim = static_cast<size_t>(reps.size());
const size_t rndim = reps.size();
size_t tndim = (ndim > rndim) ? ndim : rndim;
// re-construct data shape or reps shape
std::vector<IndexExpr> data_shape;
Expand Down
26 changes: 13 additions & 13 deletions topi/include/topi/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -785,43 +785,43 @@ inline Tensor tile(const Tensor& x,
Array<Integer> reps,
std::string name = "tensor",
std::string tag = kBroadcast) {
int ndim = static_cast<int>(x->shape.size());
int rdim = static_cast<int>(reps.size());
int tdim = (ndim > rdim) ? ndim : rdim;
size_t ndim = x->shape.size();
size_t rdim = reps.size();
size_t tdim = (ndim > rdim) ? ndim : rdim;
Array<Expr> data_shape;
Array<Expr> reps_shape;
Array<Expr> new_shape;
if (ndim == rdim) {
for (size_t i = 0; i < static_cast<size_t>(ndim); ++i) {
for (size_t i = 0; i < ndim; ++i) {
data_shape.push_back(x->shape[i]);
reps_shape.push_back(reps[i]);
}
} else if (ndim > rdim) {
for (size_t i = 0; i < static_cast<size_t>(ndim); ++i)
for (size_t i = 0; i < ndim; ++i)
data_shape.push_back(x->shape[i]);
for (size_t i = 0; i < static_cast<size_t>(ndim - rdim); ++i)
for (size_t i = 0; i < (ndim - rdim); ++i)
reps_shape.push_back(1);
for (size_t i = 0; i < static_cast<size_t>(rdim); ++i)
for (size_t i = 0; i < rdim; ++i)
reps_shape.push_back(reps[i]);
} else {
for (size_t i = 0; i < static_cast<size_t>(rdim - ndim); ++i)
for (size_t i = 0; i < (rdim - ndim); ++i)
data_shape.push_back(1);
for (size_t i = 0; i < static_cast<size_t>(ndim); ++i)
for (size_t i = 0; i < ndim; ++i)
data_shape.push_back(x->shape[i]);
for (size_t i = 0; i < static_cast<size_t>(rdim); ++i)
for (size_t i = 0; i < rdim; ++i)
reps_shape.push_back(reps[i]);
}
for (size_t i = 0; i < static_cast<size_t>(tdim); ++i)
for (size_t i = 0; i < tdim; ++i)
new_shape.push_back(data_shape[i] * reps_shape[i]);

return compute(
new_shape, [&](const Array<Var>& indices) {
Array<Expr> idx;
if (ndim >= rdim) {
for (size_t i = 0; i < static_cast<size_t>(ndim); ++i)
for (size_t i = 0; i < ndim; ++i)
idx.push_back(indices[i] % x->shape[i]);
} else {
for (size_t i = 0; i < static_cast<size_t>(ndim); ++i)
for (size_t i = 0; i < ndim; ++i)
idx.push_back(indices[rdim - ndim + i] % x->shape[i]);
}
return x(idx);
Expand Down

0 comments on commit 540195a

Please sign in to comment.