diff --git a/python/tvm/relay/frontend/tensorflow_ops.py b/python/tvm/relay/frontend/tensorflow_ops.py index 3c4a9b69ea6e..be15f83faf0f 100644 --- a/python/tvm/relay/frontend/tensorflow_ops.py +++ b/python/tvm/relay/frontend/tensorflow_ops.py @@ -2886,6 +2886,7 @@ def _impl(inputs, attr, params, mod): "GreaterEqual": _broadcast("greater_equal"), "Identity": _identity(), "IdentityN": _identityn(), + "InvertPermutation": AttrCvt("invert_permutation"), "IsFinite": AttrCvt("isfinite"), "IsInf": AttrCvt("isinf"), "IsNan": AttrCvt("isnan"), diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index f87b5ed0b8ef..bee188f19364 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -178,6 +178,10 @@ def compute_unique(attrs, inputs, output_type): _reg.register_strategy("unique", strategy.unique_strategy) +# invert_permutation +_reg.register_strategy("invert_permutation", strategy.invert_permutation_strategy) +_reg.register_shape_func("invert_permutation", False, elemwise_shape_func) + ##################### # Shape functions # ##################### diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py index b4db412700a7..6418f1f96b3b 100644 --- a/python/tvm/relay/op/strategy/cuda.py +++ b/python/tvm/relay/op/strategy/cuda.py @@ -1135,3 +1135,15 @@ def schedule_transpose_cuda(attrs, outs, target): ): return topi.cuda.schedule_transpose(outs) return schedule_injective(attrs, outs, target) + + +@invert_permutation_strategy.register(["cuda", "gpu"]) +def invert_permutation_strategy_cuda(attrs, inputs, out_type, target): + """invert_permutation cuda strategy""" + strategy = _op.OpStrategy() + strategy.add_implementation( + wrap_compute_invert_permutation(topi.cuda.invert_permutation), + wrap_topi_schedule(topi.cuda.vision._default_schedule), + name="invert_permutation.cuda", + ) + return strategy diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py index d56820e409aa..35e5177458a5 100644 --- a/python/tvm/relay/op/strategy/generic.py +++ b/python/tvm/relay/op/strategy/generic.py @@ -1607,3 +1607,25 @@ def schedule_transpose(attrs, outs, target): """schedule transpose""" with target: return schedule_injective(attrs, outs, target) + + +# invert_permutation +def wrap_compute_invert_permutation(topi_compute): + """wrap invert_permutation topi compute""" + + def _compute_invert_permutation(attrs, inputs, out_type): + return [topi_compute(inputs[0])] + + return _compute_invert_permutation + + +@override_native_generic_func("invert_permutation_strategy") +def invert_permutation_strategy(attrs, inputs, out_type, target): + """invert_permutation generic strategy""" + strategy = _op.OpStrategy() + strategy.add_implementation( + wrap_compute_invert_permutation(topi.invert_permutation), + wrap_topi_schedule(topi.generic.schedule_injective), + name="invert_permutation.generic", + ) + return strategy diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index 440d2fae042f..9cb50ed6548a 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -1720,3 +1720,31 @@ def unique(data, is_sorted=True, return_counts=False): if return_counts: return TupleWrapper(_make.unique(data, is_sorted, return_counts), 5) return TupleWrapper(_make.unique(data, is_sorted, return_counts), 4) + + +def invert_permutation(data): + """Computes the inverse permutation of data. + This operation computes the inverse of an index permutation. + It takes a 1-D integer tensor x, which represents the indices of a zero-based + array and swaps each value with its index position. + + For an output tensor y and an input tensor x, this operation computes the following: + y[x[i]] = i for i in [0, 1, ..., len(x) - 1] + + Parameters + ---------- + data : relay.Expr + The source data to be invert permuated. + + Returns + ------- + ret : relay.Expr + Invert permuated data. Has the same type as data. + + Examples + -------- + .. code-block:: python + data = [3, 4, 0, 2, 1] + relay.invert_permutation(data) = [2, 4, 3, 0, 1] + """ + return _make.invert_permutation(data) diff --git a/python/tvm/topi/cuda/transform.py b/python/tvm/topi/cuda/transform.py index 89caf94bbbc1..16b1273def47 100644 --- a/python/tvm/topi/cuda/transform.py +++ b/python/tvm/topi/cuda/transform.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. """CUDA implementations of transforms""" - +import tvm from ... import te from ...target import Target from ..utils import traverse_inline @@ -65,3 +65,74 @@ def _callback(op): s[c].bind(ao, thread_y) traverse_inline(s, out.op, _callback) + + +def _invert_permutation_ir(data, out): + """Low level IR to get invert_permutation. + + Parameters + ---------- + data : Buffer + Input data. 1-D Buffer with shape [elem_num]. + + out : Buffer + 1D buffer for invert permutation result with the same shape with data. + + Returns + ------- + stmt : Stmt + The result IR statement. + """ + elem_num = data.shape[0] + + irb = tvm.tir.ir_builder.create() + data = irb.buffer_ptr(data) + out = irb.buffer_ptr(out) + + max_threads = int(Target.current(allow_none=False).max_num_threads) + nthread_tx = max_threads + nthread_bx = elem_num // max_threads + 1 + thread_x = te.thread_axis("threadIdx.x") + block_x = te.thread_axis("blockIdx.x") + irb.scope_attr(thread_x, "thread_extent", nthread_tx) + irb.scope_attr(block_x, "thread_extent", nthread_bx) + tid = block_x * max_threads + thread_x + + with irb.if_scope(tid < elem_num): + r_ind = data[tid] + out[r_ind] = tid + return irb.get() + + +def invert_permutation(data): + """Compute definition of invert_permutation. + For an output tensor y and an input tensor x, this operation computes the following: + + y[x[i]] = i for i in [0, 1, ..., len(x) - 1] + + Parameters + ---------- + data : tvm.te.Tensor + 1-D tensor + + Returns + ------- + out : tvm.te.Tensor + """ + data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8) + out_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "out_buf", data_alignment=8) + + out = te.extern( + [data.shape], + [data], + lambda ins, outs: _invert_permutation_ir(ins[0], outs[0]), + in_buffers=[ + data_buf, + ], + out_buffers=[ + out_buf, + ], + name="invert_permutation", + tag="invert_permutation_gpu", + ) + return out diff --git a/python/tvm/topi/transform.py b/python/tvm/topi/transform.py index b4d0167be2b1..45756eadbcdb 100644 --- a/python/tvm/topi/transform.py +++ b/python/tvm/topi/transform.py @@ -20,6 +20,7 @@ import tvm from tvm import te from tvm import topi +from tvm.te import hybrid from . import cpp from . import tag from .utils import within_index, make_idx, const_vector @@ -941,3 +942,31 @@ def adv_index(data, indices): Output tensor """ return cpp.adv_index(data, indices) + + +@hybrid.script +def invert_permutation(data): + """Computes the inverse permutation of data. + + Parameters + ---------- + data : tvm.te.Tensor + Input data + + Returns + ------- + result : tvm.te.Tensor + Output tensor + + Examples + -------- + .. code-block:: python + data = [3, 4, 0, 2, 1] + topi.invert_permutation(data) = [2, 4, 3, 0, 1] + """ + result = output_tensor(data.shape, data.dtype) + nums = data.shape[0] + for ind in range(nums): + r_ind = data[ind] + result[r_ind] = ind + return result diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 7d40bf22bcee..5dc2a677f13f 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -3976,5 +3976,23 @@ RELAY_REGISTER_OP("unique") .add_type_rel("unique", UniqueRel) .set_support_level(3) .set_attr("TOpPattern", kOpaque); + +// invert_permutation +Expr MakeInvertPermutation(Expr data) { + static const Op& op = Op::Get("invert_permutation"); + return Call(op, {data}, Attrs(), {}); +} + +TVM_REGISTER_GLOBAL("relay.op._make.invert_permutation").set_body_typed(MakeInvertPermutation); + +RELAY_REGISTER_OP("invert_permutation") + .describe(R"doc(Computes the inverse permutation of a tensor.)doc" TVM_ADD_FILELINE) + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .add_type_rel("Identity", IdentityRel) + .set_support_level(1) + .set_attr("TOpPattern", kInjective) + .set_attr("TOpIsStateful", false); + } // namespace relay } // namespace tvm diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index 57497d04706a..0ef3317525b3 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -5569,5 +5569,23 @@ def @main(%A: Tensor[(4, 176, 8, 8), float32]) { tvm.ir.assert_structural_equal(mod["main"].body, mod_golden["main"].body, map_free_vars=True) +####################################################################### +# invert_permutation +# -------------------- + + +def test_invert_permutation(): + """test InvertPermutation""" + tf.reset_default_graph() + + input_shape = [6] + x = np.array([3, 4, 0, 2, 1, 5]).astype("int32") + with tf.Graph().as_default(): + in_data = tf.placeholder(shape=input_shape, dtype="int32") + tf.invert_permutation(in_data) + out_name = "InvertPermutation:0" + compare_tf_with_tvm(x, "Placeholder:0", out_name, no_gpu=False) + + if __name__ == "__main__": pytest.main([__file__])