Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Topi] Allow empty tensor for reshape, tile and strided_slice #4618

Merged
merged 6 commits into from
Jan 6, 2020
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/relay/op/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -888,6 +888,7 @@ bool TakeRel(const Array<Type>& types,
CHECK(data != nullptr);
const auto* indices = types[1].as<TensorTypeNode>();
CHECK(indices != nullptr);
CHECK(indices->dtype.is_int()) << "indices of take must be tensor of integer";
const auto param = attrs.as<TakeAttrs>();
CHECK(param != nullptr);

Expand Down Expand Up @@ -1648,6 +1649,9 @@ bool SqueezeRel(const Array<Type>& types,
// if axes is None, squeeze all axes of dimension 1
if (!param->axis.defined()) {
for (const auto& e : data->shape) {
if (!e.as<IntImm>()) {
LOG(FATAL) << "axis needs to be defined for dynamic input.";
}
const int64_t* axis_ptr = as_const_int(e);
CHECK(axis_ptr != nullptr) << "the axes attribute must be concrete";
if (*axis_ptr != 1) {
Expand Down
55 changes: 55 additions & 0 deletions topi/include/topi/detail/tensor_utils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file tensor_utils.h
* \brief Utility functions for handling tensor
*/
#ifndef TOPI_DETAIL_TENSOR_UTILS_H_
#define TOPI_DETAIL_TENSOR_UTILS_H_


namespace topi {
namespace detail {
using namespace tvm;

/*!
 * \brief Check whether the input shape contains a dimension of size 0.
 *
 * Only statically-known (IntImm) extents can be proven to be zero;
 * symbolic dimensions are skipped, so a shape with only symbolic
 * extents is never reported as empty.
 *
 * \param x Input shape
 *
 * \return True if any constant dimension of the input shape is 0.
 */
inline bool is_empty_shape(const Array<Expr>& x) {
  for (const auto& dim : x) {
    // Early return as soon as one constant zero extent is found.
    if (const auto int_dim = dim.as<IntImm>()) {
      if (int_dim->value == 0) {
        return true;
      }
    }
  }
  return false;
}

} // namespace detail
} // namespace topi
#endif // TOPI_DETAIL_TENSOR_UTILS_H_

59 changes: 39 additions & 20 deletions topi/include/topi/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
#include "topi/tags.h"
#include "topi/detail/ravel_unravel.h"
#include "topi/detail/constant_utils.h"
#include "topi/detail/tensor_utils.h"
#include "tvm/operation.h"
#include "tvm/expr_operator.h"
#include "tvm/data_layout.h"
Expand Down Expand Up @@ -207,16 +208,28 @@ inline Tensor reshape(const Tensor& x,
std::string name = "T_reshape",
std::string tag = kInjective) {
auto x_shape = x->shape;
Array<Expr> newshape_int32;
Array<Expr> target_shape;

for (const auto &ele : newshape) {
newshape_int32.push_back(cast(DataType::Int(32), ele));
if (ele.as<IntImm>()) {
target_shape.push_back(cast(DataType::Int(32), ele));
} else {
target_shape.push_back(ele);
}
}

if (is_empty_shape(target_shape)) {
return compute(target_shape,
[&](const Array<Var> &indices) { return tvm::cast(x->dtype, 0); },
name, tag);
} else {
return compute(
target_shape, [&](const Array<Var>& indices) {
return x(UnravelIndex(
RavelIndex(Array<Expr>{indices.begin(), indices.end()}, target_shape),
x_shape));
}, name, tag);
}
return compute(
newshape_int32, [&](const Array<Var>& indices) {
return x(UnravelIndex(RavelIndex(Array<Expr>{indices.begin(), indices.end()}, newshape_int32),
x_shape));
}, name, tag);
}

/*!
Expand Down Expand Up @@ -556,7 +569,7 @@ inline Tensor strided_slice(const Tensor& x,
int interval = std::abs(end_i - begin_i);
int slice_size = static_cast<int>((interval
+ std::abs(stride_vec[i]) - 1) / std::abs(stride_vec[i]));
CHECK(stride_vec[i] < 0 ? (end_i < begin_i) : (begin_i < end_i))
CHECK(stride_vec[i] < 0 ? (end_i <= begin_i) : (begin_i <= end_i))
<< ": Input [Begin=" << begin_vec[i] << ", End=" << end_vec[i]
<< "] is invalid for axis=" << i;

Expand Down Expand Up @@ -938,18 +951,24 @@ inline Tensor tile(const Tensor& x,
for (size_t i = 0; i < tdim; ++i)
new_shape.push_back(data_shape[i] * reps_shape[i]);

return compute(
new_shape, [&](const Array<Var>& indices) {
Array<Expr> idx;
if (ndim >= rdim) {
for (size_t i = 0; i < ndim; ++i)
idx.push_back(indexmod(indices[i], x->shape[i]));
} else {
for (size_t i = 0; i < ndim; ++i)
idx.push_back(indexmod(indices[rdim - ndim + i], x->shape[i]));
}
return x(idx);
}, name, tag);
if (is_empty_shape(new_shape)) {
return compute(new_shape,
[&](const Array<Var>& indices) { return tvm::cast(x->dtype, 0);},
name, tag);
} else {
return compute(
new_shape, [&](const Array<Var>& indices) {
Array<Expr> idx;
if (ndim >= rdim) {
for (size_t i = 0; i < ndim; ++i)
idx.push_back(indexmod(indices[i], x->shape[i]));
} else {
for (size_t i = 0; i < ndim; ++i)
idx.push_back(indexmod(indices[rdim - ndim + i], x->shape[i]));
}
return x(idx);
}, name, tag);
}
}

/*!
Expand Down
5 changes: 4 additions & 1 deletion topi/python/topi/arm_cpu/injective.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
"""Schedule for pooling operators"""
import tvm
from .. import generic
from ..util import is_empty_shape

@generic.schedule_injective_from_existing.register(["arm_cpu"])
def schedule_injective_from_existing(sch, out):
Expand Down Expand Up @@ -68,7 +69,9 @@ def schedule_injective(outs):
(io, ii) = s[x].split(list(s[x].op.axis)[-1], 8)
s[x].vectorize(ii)
tvm.schedule.AutoInlineInjective(s)
schedule_injective_from_existing(s, x)

if not is_empty_shape(x.shape):
schedule_injective_from_existing(s, x)
return s

@generic.schedule_concatenate.register(["arm_cpu"])
Expand Down
4 changes: 3 additions & 1 deletion topi/python/topi/cuda/injective.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
"""Schedule for composition of injective operator"""
import tvm
from .. import generic, util
from ..util import is_empty_shape

@generic.schedule_injective_from_existing.register(["cuda", "gpu"])
def schedule_injective_from_existing(sch, out):
Expand Down Expand Up @@ -79,7 +80,8 @@ def schedule_injective(outs):

tvm.schedule.AutoInlineInjective(s)
for out in outs:
schedule_injective_from_existing(s, out)
if not is_empty_shape(out.shape):
schedule_injective_from_existing(s, out)
return s

schedule_elemwise = schedule_injective
Expand Down
22 changes: 22 additions & 0 deletions topi/python/topi/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,3 +417,25 @@ def make_idx(b, e, s, z, i):
(b - i) // tvm.abs(s),
(i - b) // s)
return tvm.if_then_else(tvm.expr.Or(bc, ec), 88, ss)


def is_empty_shape(shape):
    """Check whether an input shape has a dimension with size 0.

    Only statically-known extents (``tvm.expr.IntImm``) can be proven to
    be zero; symbolic dimensions are skipped, so a shape made of only
    symbolic extents is never reported as empty.

    Parameters
    ----------
    shape : list of Expr
        Input shape

    Returns
    -------
    is_empty : bool
        Whether the input shape is empty or has a dimension with size 0.
    """
    # Short-circuits on the first constant zero extent, equivalent to the
    # explicit loop-with-break form but idiomatic Python.
    return any(isinstance(dim, tvm.expr.IntImm) and dim.value == 0
               for dim in shape)
5 changes: 4 additions & 1 deletion topi/python/topi/x86/injective.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from __future__ import absolute_import as _abs
import tvm
from .. import generic
from ..util import is_empty_shape

@generic.schedule_injective_from_existing.register(["cpu"])
def schedule_injective_from_existing(sch, out):
Expand Down Expand Up @@ -65,7 +66,9 @@ def schedule_injective(outs):
x = outs[0]
s = tvm.create_schedule([x.op for x in outs])
tvm.schedule.AutoInlineInjective(s)
schedule_injective_from_existing(s, x)

if not is_empty_shape(x.shape):
schedule_injective_from_existing(s, x)
return s

@generic.schedule_concatenate.register(["cpu"])
Expand Down
3 changes: 3 additions & 0 deletions topi/tests/python/test_topi_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -555,6 +555,7 @@ def test_strided_slice():
verify_strided_slice((3, 4, 3), [1, 0, 0], [2, 2, 3], [1, 1, 2])
verify_strided_slice((3, 4, 3), [1, -1, 0], [2, -3, 3], [1, -1, 1])
verify_strided_slice((3, 4, 3), [1, 1, 0], [4, 4, 3])
verify_strided_slice((3, 4, 3), [0, 2, 0], [1, 2, 3])

def test_strided_set():
verify_strided_set((3, 4, 3), (3, 2, 2), [0, 3, 0], [4, 1, 4], [1, -1, 2])
Expand Down Expand Up @@ -596,6 +597,7 @@ def test_reshape():
verify_reshape((4, 2, 3, 4), (2, 4, 12))
verify_reshape((4, 2, 3, 4), (2, 48))
verify_reshape((16, ), (2, 2, 2, 2))
verify_reshape((4, 0), (2, 0, 2))


def test_where():
Expand Down Expand Up @@ -718,6 +720,7 @@ def test_tile():
verify_tile((3, 2), (2, 3))
verify_tile((3, 2, 5), (2,))
verify_tile((3, ), (2, 3, 3))
verify_tile((4, 0), (5,))

def test_layout_transform():
in_shape = (1, 32, 8, 8)
Expand Down