[TensorFlow][Frontend] Adding InversePermutation Op #8277

Merged: 7 commits, Jun 21, 2021
Changes from all commits
1 change: 1 addition & 0 deletions python/tvm/relay/frontend/tensorflow_ops.py
@@ -2886,6 +2886,7 @@ def _impl(inputs, attr, params, mod):
"GreaterEqual": _broadcast("greater_equal"),
"Identity": _identity(),
"IdentityN": _identityn(),
"InvertPermutation": AttrCvt("invert_permutation"),
"IsFinite": AttrCvt("isfinite"),
"IsInf": AttrCvt("isinf"),
"IsNan": AttrCvt("isnan"),
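The single mapping line above is all the frontend needs, since the TF `InvertPermutation` op carries no attributes. A minimal sketch of exercising the mapping, assuming TF 1.x-style graph building as in the test at the bottom of this PR:

import tensorflow.compat.v1 as tf
from tvm import relay

# Build a TF graph containing a single InvertPermutation node.
with tf.Graph().as_default() as graph:
    perm = tf.placeholder(shape=[5], dtype="int32", name="perm")
    tf.invert_permutation(perm)
    graph_def = graph.as_graph_def()

# AttrCvt maps the node straight onto relay invert_permutation.
mod, params = relay.frontend.from_tensorflow(graph_def, shape={"perm": [5]})
print(mod)  # main() now contains a call to invert_permutation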
4 changes: 4 additions & 0 deletions python/tvm/relay/op/_transform.py
@@ -178,6 +178,10 @@ def compute_unique(attrs, inputs, output_type):

_reg.register_strategy("unique", strategy.unique_strategy)

# invert_permutation
_reg.register_strategy("invert_permutation", strategy.invert_permutation_strategy)
_reg.register_shape_func("invert_permutation", False, elemwise_shape_func)

#####################
# Shape functions #
#####################
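The shape function can simply be `elemwise_shape_func` because the output always has exactly the input's shape and dtype. The semantics in plain NumPy, as a reference sketch:

import numpy as np

def invert_permutation_ref(x):
    """Reference semantics: y[x[i]] = i."""
    y = np.empty_like(x)       # output shape/dtype match the input exactly
    y[x] = np.arange(len(x))   # scatter each index into the slot it names
    return y

print(invert_permutation_ref(np.array([3, 4, 0, 2, 1])))  # [2 4 3 0 1]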
12 changes: 12 additions & 0 deletions python/tvm/relay/op/strategy/cuda.py
@@ -1135,3 +1135,15 @@ def schedule_transpose_cuda(attrs, outs, target):
    ):
        return topi.cuda.schedule_transpose(outs)
    return schedule_injective(attrs, outs, target)


@invert_permutation_strategy.register(["cuda", "gpu"])
def invert_permutation_strategy_cuda(attrs, inputs, out_type, target):
    """invert_permutation cuda strategy"""
    strategy = _op.OpStrategy()
    strategy.add_implementation(
        wrap_compute_invert_permutation(topi.cuda.invert_permutation),
        wrap_topi_schedule(topi.cuda.vision._default_schedule),
        name="invert_permutation.cuda",
    )
    return strategy
22 changes: 22 additions & 0 deletions python/tvm/relay/op/strategy/generic.py
@@ -1607,3 +1607,25 @@ def schedule_transpose(attrs, outs, target):
    """schedule transpose"""
    with target:
        return schedule_injective(attrs, outs, target)


# invert_permutation
def wrap_compute_invert_permutation(topi_compute):
    """wrap invert_permutation topi compute"""

    def _compute_invert_permutation(attrs, inputs, out_type):
        return [topi_compute(inputs[0])]

    return _compute_invert_permutation


@override_native_generic_func("invert_permutation_strategy")
def invert_permutation_strategy(attrs, inputs, out_type, target):
    """invert_permutation generic strategy"""
    strategy = _op.OpStrategy()
    strategy.add_implementation(
        wrap_compute_invert_permutation(topi.invert_permutation),
        wrap_topi_schedule(topi.generic.schedule_injective),
        name="invert_permutation.generic",
    )
    return strategy
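Other backends can override this generic strategy the same way the CUDA file above does. A hypothetical sketch, where `mytarget` and `topi.mytarget.*` are placeholders rather than real TVM modules:

# Hypothetical: how another backend would plug in its own implementation.
# "mytarget" and topi.mytarget.* are placeholders for illustration only.
@invert_permutation_strategy.register(["mytarget"])
def invert_permutation_strategy_mytarget(attrs, inputs, out_type, target):
    """invert_permutation strategy for a hypothetical target"""
    strategy = _op.OpStrategy()
    strategy.add_implementation(
        wrap_compute_invert_permutation(topi.mytarget.invert_permutation),
        wrap_topi_schedule(topi.mytarget.schedule_injective),
        name="invert_permutation.mytarget",
    )
    return strategy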
28 changes: 28 additions & 0 deletions python/tvm/relay/op/transform.py
@@ -1720,3 +1720,31 @@ def unique(data, is_sorted=True, return_counts=False):
    if return_counts:
        return TupleWrapper(_make.unique(data, is_sorted, return_counts), 5)
    return TupleWrapper(_make.unique(data, is_sorted, return_counts), 4)


def invert_permutation(data):
    """Computes the inverse permutation of data.

    This operation computes the inverse of an index permutation. It takes a
    1-D integer tensor x, which represents the indices of a zero-based array,
    and swaps each value with its index position. For an output tensor y and
    an input tensor x, this operation computes the following:

        y[x[i]] = i for i in [0, 1, ..., len(x) - 1]

    Parameters
    ----------
    data : relay.Expr
        The source permutation to be inverted.

    Returns
    -------
    ret : relay.Expr
        The inverse permutation. Has the same type as data.

    Examples
    --------
    .. code-block:: python

        data = [3, 4, 0, 2, 1]
        relay.invert_permutation(data) = [2, 4, 3, 0, 1]
    """
    return _make.invert_permutation(data)
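A quick end-to-end check of the new Relay op, as a sketch (assuming a TVM build where NDArray has `.numpy()`; older builds use `.asnumpy()`):

import numpy as np
import tvm
from tvm import relay

x = relay.var("x", shape=(5,), dtype="int32")
func = relay.Function([x], relay.invert_permutation(x))
mod = tvm.IRModule.from_expr(func)

ex = relay.create_executor("graph", mod=mod, target="llvm")
out = ex.evaluate()(np.array([3, 4, 0, 2, 1], dtype="int32"))
print(out.numpy())  # [2 4 3 0 1]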
73 changes: 72 additions & 1 deletion python/tvm/topi/cuda/transform.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
"""CUDA implementations of transforms"""

import tvm
from ... import te
from ...target import Target
from ..utils import traverse_inline
@@ -65,3 +65,74 @@ def _callback(op):
        s[c].bind(ao, thread_y)

    traverse_inline(s, out.op, _callback)


def _invert_permutation_ir(data, out):
    """Low level IR to compute invert_permutation.

    Parameters
    ----------
    data : Buffer
        Input data. 1-D Buffer with shape [elem_num].

    out : Buffer
        1-D Buffer for the inverted permutation, with the same shape as data.

    Returns
    -------
    stmt : Stmt
        The result IR statement.
    """
    elem_num = data.shape[0]

    irb = tvm.tir.ir_builder.create()
    data = irb.buffer_ptr(data)
    out = irb.buffer_ptr(out)

    max_threads = int(Target.current(allow_none=False).max_num_threads)
    nthread_tx = max_threads
    nthread_bx = elem_num // max_threads + 1
    thread_x = te.thread_axis("threadIdx.x")
    block_x = te.thread_axis("blockIdx.x")
    irb.scope_attr(thread_x, "thread_extent", nthread_tx)
    irb.scope_attr(block_x, "thread_extent", nthread_bx)
    tid = block_x * max_threads + thread_x

    # One thread per element: thread tid scatters its own index into the
    # output slot named by its input value. The bound check keeps the
    # surplus threads of the last (possibly partial) block idle.
    with irb.if_scope(tid < elem_num):
        r_ind = data[tid]
        out[r_ind] = tid
    return irb.get()


def invert_permutation(data):
    """Compute definition of invert_permutation.

    For an output tensor y and an input tensor x, this operation computes
    the following:

        y[x[i]] = i for i in [0, 1, ..., len(x) - 1]

    Parameters
    ----------
    data : tvm.te.Tensor
        1-D tensor

    Returns
    -------
    out : tvm.te.Tensor
        The inverse permutation, with the same shape and dtype as data.
    """
    data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8)
    out_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "out_buf", data_alignment=8)

    out = te.extern(
        [data.shape],
        [data],
        lambda ins, outs: _invert_permutation_ir(ins[0], outs[0]),
        in_buffers=[
            data_buf,
        ],
        out_buffers=[
            out_buf,
        ],
        name="invert_permutation",
        tag="invert_permutation_gpu",
    )
    return out
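The extern op above can be built and run directly, as in this sketch (requires a CUDA-capable device; `tvm.cuda` assumes a recent TVM, older builds spell it `tvm.gpu`):

import numpy as np
import tvm
from tvm import te, topi

data = te.placeholder((6,), dtype="int32", name="data")
out = topi.cuda.invert_permutation(data)
s = te.create_schedule(out.op)  # extern op: the IR above already fixes the kernel
fn = tvm.build(s, [data, out], "cuda")

dev = tvm.cuda(0)
x = tvm.nd.array(np.array([3, 4, 0, 2, 1, 5], dtype="int32"), dev)
y = tvm.nd.array(np.zeros(6, dtype="int32"), dev)
fn(x, y)
print(y.numpy())  # [2 4 3 0 1 5]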
29 changes: 29 additions & 0 deletions python/tvm/topi/transform.py
@@ -20,6 +20,7 @@
import tvm
from tvm import te
from tvm import topi
from tvm.te import hybrid
from . import cpp
from . import tag
from .utils import within_index, make_idx, const_vector
@@ -941,3 +942,31 @@ def adv_index(data, indices):
        Output tensor
    """
    return cpp.adv_index(data, indices)


@hybrid.script
def invert_permutation(data):
    """Computes the inverse permutation of data.

    Parameters
    ----------
    data : tvm.te.Tensor
        Input data

    Returns
    -------
    result : tvm.te.Tensor
        Output tensor

    Examples
    --------
    .. code-block:: python

        data = [3, 4, 0, 2, 1]
        topi.invert_permutation(data) = [2, 4, 3, 0, 1]
    """
    result = output_tensor(data.shape, data.dtype)
    nums = data.shape[0]
    for ind in range(nums):
        r_ind = data[ind]
        result[r_ind] = ind
    return result
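Note why this is a hybrid script rather than a plain `te.compute`: the loop scatters (`result[data[ind]] = ind`) instead of gathering, and `te.compute` can only express gather-style reads; the CUDA version uses the IR builder for the same reason. A sketch of running the generic compute on LLVM:

import numpy as np
import tvm
from tvm import te, topi

data = te.placeholder((5,), dtype="int32", name="data")
out = topi.invert_permutation(data)
s = te.create_schedule(out.op)
fn = tvm.build(s, [data, out], "llvm")

x = tvm.nd.array(np.array([3, 4, 0, 2, 1], dtype="int32"))
y = tvm.nd.array(np.zeros(5, dtype="int32"))
fn(x, y)
print(y.numpy())  # [2 4 3 0 1]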
18 changes: 18 additions & 0 deletions src/relay/op/tensor/transform.cc
@@ -3976,5 +3976,23 @@ RELAY_REGISTER_OP("unique")
.add_type_rel("unique", UniqueRel)
.set_support_level(3)
.set_attr<TOpPattern>("TOpPattern", kOpaque);

// invert_permutation
Expr MakeInvertPermutation(Expr data) {
  static const Op& op = Op::Get("invert_permutation");
  return Call(op, {data}, Attrs(), {});
}

TVM_REGISTER_GLOBAL("relay.op._make.invert_permutation").set_body_typed(MakeInvertPermutation);

RELAY_REGISTER_OP("invert_permutation")
.describe(R"doc(Computes the inverse permutation of a tensor.)doc" TVM_ADD_FILELINE)
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor.")
.add_type_rel("Identity", IdentityRel)
.set_support_level(1)
.set_attr<TOpPattern>("TOpPattern", kInjective)
.set_attr<TOpIsStateful>("TOpIsStateful", false);

} // namespace relay
} // namespace tvm
18 changes: 18 additions & 0 deletions tests/python/frontend/tensorflow/test_forward.py
@@ -5569,5 +5569,23 @@ def @main(%A: Tensor[(4, 176, 8, 8), float32]) {
    tvm.ir.assert_structural_equal(mod["main"].body, mod_golden["main"].body, map_free_vars=True)


#######################################################################
# invert_permutation
# --------------------


def test_invert_permutation():
    """test InvertPermutation"""
    tf.reset_default_graph()

    input_shape = [6]
    x = np.array([3, 4, 0, 2, 1, 5]).astype("int32")
    with tf.Graph().as_default():
        in_data = tf.placeholder(shape=input_shape, dtype="int32")
        tf.invert_permutation(in_data)
        out_name = "InvertPermutation:0"
        compare_tf_with_tvm(x, "Placeholder:0", out_name, no_gpu=False)


if __name__ == "__main__":
    pytest.main([__file__])