Skip to content

Commit

Permalink
[TensorFlow][Frontend] Adding InversePermutation Op (apache#8277)
Browse files Browse the repository at this point in the history
* [TensorFlow][Frontend] Adding InversePermutation Op

Computes the inverse permutation of a tensor. This Op is used by Mask R-CNN
or other object detection models.

* uncomment test_read_variable_op

* restore several tests

* fix lint error

* fix python linting error

* fix lint error

* restore mistakenly deleted codes
  • Loading branch information
cailun01 authored and ylc committed Jan 13, 2022
1 parent fbae2ed commit dc1f54b
Show file tree
Hide file tree
Showing 9 changed files with 204 additions and 1 deletion.
1 change: 1 addition & 0 deletions python/tvm/relay/frontend/tensorflow_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -2886,6 +2886,7 @@ def _impl(inputs, attr, params, mod):
"GreaterEqual": _broadcast("greater_equal"),
"Identity": _identity(),
"IdentityN": _identityn(),
"InvertPermutation": AttrCvt("invert_permutation"),
"IsFinite": AttrCvt("isfinite"),
"IsInf": AttrCvt("isinf"),
"IsNan": AttrCvt("isnan"),
Expand Down
4 changes: 4 additions & 0 deletions python/tvm/relay/op/_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,10 @@ def compute_unique(attrs, inputs, output_type):

_reg.register_strategy("unique", strategy.unique_strategy)

# invert_permutation
_reg.register_strategy("invert_permutation", strategy.invert_permutation_strategy)
_reg.register_shape_func("invert_permutation", False, elemwise_shape_func)

#####################
# Shape functions #
#####################
Expand Down
12 changes: 12 additions & 0 deletions python/tvm/relay/op/strategy/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -1135,3 +1135,15 @@ def schedule_transpose_cuda(attrs, outs, target):
):
return topi.cuda.schedule_transpose(outs)
return schedule_injective(attrs, outs, target)


@invert_permutation_strategy.register(["cuda", "gpu"])
def invert_permutation_strategy_cuda(attrs, inputs, out_type, target):
"""invert_permutation cuda strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_invert_permutation(topi.cuda.invert_permutation),
wrap_topi_schedule(topi.cuda.vision._default_schedule),
name="invert_permutation.cuda",
)
return strategy
22 changes: 22 additions & 0 deletions python/tvm/relay/op/strategy/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1615,3 +1615,25 @@ def schedule_transpose(attrs, outs, target):
"""schedule transpose"""
with target:
return schedule_injective(attrs, outs, target)


# invert_permutation
def wrap_compute_invert_permutation(topi_compute):
"""wrap invert_permutation topi compute"""

def _compute_invert_permutation(attrs, inputs, out_type):
return [topi_compute(inputs[0])]

return _compute_invert_permutation


@override_native_generic_func("invert_permutation_strategy")
def invert_permutation_strategy(attrs, inputs, out_type, target):
"""invert_permutation generic strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_invert_permutation(topi.invert_permutation),
wrap_topi_schedule(topi.generic.schedule_injective),
name="invert_permutation.generic",
)
return strategy
28 changes: 28 additions & 0 deletions python/tvm/relay/op/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -1759,3 +1759,31 @@ def unique(data, is_sorted=True, return_counts=False):
if return_counts:
return TupleWrapper(_make.unique(data, is_sorted, return_counts), 5)
return TupleWrapper(_make.unique(data, is_sorted, return_counts), 4)


def invert_permutation(data):
"""Computes the inverse permutation of data.
This operation computes the inverse of an index permutation.
It takes a 1-D integer tensor x, which represents the indices of a zero-based
array and swaps each value with its index position.
For an output tensor y and an input tensor x, this operation computes the following:
y[x[i]] = i for i in [0, 1, ..., len(x) - 1]
Parameters
----------
data : relay.Expr
The source data to be invert permuated.
Returns
-------
ret : relay.Expr
Invert permuated data. Has the same type as data.
Examples
--------
.. code-block:: python
data = [3, 4, 0, 2, 1]
relay.invert_permutation(data) = [2, 4, 3, 0, 1]
"""
return _make.invert_permutation(data)
73 changes: 72 additions & 1 deletion python/tvm/topi/cuda/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
"""CUDA implementations of transforms"""

import tvm
from ... import te
from ...target import Target
from ..utils import traverse_inline
Expand Down Expand Up @@ -65,3 +65,74 @@ def _callback(op):
s[c].bind(ao, thread_y)

traverse_inline(s, out.op, _callback)


def _invert_permutation_ir(data, out):
"""Low level IR to get invert_permutation.
Parameters
----------
data : Buffer
Input data. 1-D Buffer with shape [elem_num].
out : Buffer
1D buffer for invert permutation result with the same shape with data.
Returns
-------
stmt : Stmt
The result IR statement.
"""
elem_num = data.shape[0]

irb = tvm.tir.ir_builder.create()
data = irb.buffer_ptr(data)
out = irb.buffer_ptr(out)

max_threads = int(Target.current(allow_none=False).max_num_threads)
nthread_tx = max_threads
nthread_bx = elem_num // max_threads + 1
thread_x = te.thread_axis("threadIdx.x")
block_x = te.thread_axis("blockIdx.x")
irb.scope_attr(thread_x, "thread_extent", nthread_tx)
irb.scope_attr(block_x, "thread_extent", nthread_bx)
tid = block_x * max_threads + thread_x

with irb.if_scope(tid < elem_num):
r_ind = data[tid]
out[r_ind] = tid
return irb.get()


def invert_permutation(data):
"""Compute definition of invert_permutation.
For an output tensor y and an input tensor x, this operation computes the following:
y[x[i]] = i for i in [0, 1, ..., len(x) - 1]
Parameters
----------
data : tvm.te.Tensor
1-D tensor
Returns
-------
out : tvm.te.Tensor
"""
data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8)
out_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "out_buf", data_alignment=8)

out = te.extern(
[data.shape],
[data],
lambda ins, outs: _invert_permutation_ir(ins[0], outs[0]),
in_buffers=[
data_buf,
],
out_buffers=[
out_buf,
],
name="invert_permutation",
tag="invert_permutation_gpu",
)
return out
29 changes: 29 additions & 0 deletions python/tvm/topi/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import tvm
from tvm import te
from tvm import topi
from tvm.te import hybrid
from . import cpp
from . import tag
from .utils import within_index, make_idx, const_vector
Expand Down Expand Up @@ -941,3 +942,31 @@ def adv_index(data, indices):
Output tensor
"""
return cpp.adv_index(data, indices)


@hybrid.script
def invert_permutation(data):
"""Computes the inverse permutation of data.
Parameters
----------
data : tvm.te.Tensor
Input data
Returns
-------
result : tvm.te.Tensor
Output tensor
Examples
--------
.. code-block:: python
data = [3, 4, 0, 2, 1]
topi.invert_permutation(data) = [2, 4, 3, 0, 1]
"""
result = output_tensor(data.shape, data.dtype)
nums = data.shape[0]
for ind in range(nums):
r_ind = data[ind]
result[r_ind] = ind
return result
18 changes: 18 additions & 0 deletions src/relay/op/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4028,5 +4028,23 @@ RELAY_REGISTER_OP("unique")
.add_type_rel("unique", UniqueRel)
.set_support_level(3)
.set_attr<TOpPattern>("TOpPattern", kOpaque);

// invert_permutation
Expr MakeInvertPermutation(Expr data) {
static const Op& op = Op::Get("invert_permutation");
return Call(op, {data}, Attrs(), {});
}

TVM_REGISTER_GLOBAL("relay.op._make.invert_permutation").set_body_typed(MakeInvertPermutation);

RELAY_REGISTER_OP("invert_permutation")
.describe(R"doc(Computes the inverse permutation of a tensor.)doc" TVM_ADD_FILELINE)
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor.")
.add_type_rel("Identity", IdentityRel)
.set_support_level(1)
.set_attr<TOpPattern>("TOpPattern", kInjective)
.set_attr<TOpIsStateful>("TOpIsStateful", false);

} // namespace relay
} // namespace tvm
18 changes: 18 additions & 0 deletions tests/python/frontend/tensorflow/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -5585,5 +5585,23 @@ def @main(%A: Tensor[(4, 176, 8, 8), float32]) {
tvm.ir.assert_structural_equal(mod["main"].body, mod_golden["main"].body, map_free_vars=True)


#######################################################################
# invert_permutation
# --------------------


def test_invert_permutation():
"""test InvertPermutation"""
tf.reset_default_graph()

input_shape = [6]
x = np.array([3, 4, 0, 2, 1, 5]).astype("int32")
with tf.Graph().as_default():
in_data = tf.placeholder(shape=input_shape, dtype="int32")
tf.invert_permutation(in_data)
out_name = "InvertPermutation:0"
compare_tf_with_tvm(x, "Placeholder:0", out_name, no_gpu=False)


if __name__ == "__main__":
pytest.main([__file__])

0 comments on commit dc1f54b

Please sign in to comment.