diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index e42b8bbae814..439d44b5790b 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -332,6 +332,25 @@ def take_shape_func(attrs, inputs, out_ndims): return [_take_with_axis_shape_func(*inputs, convert(axis), out_ndims[0])] +@_reg.register_legalize("take") +def legalize_dyn_topk(attrs, inputs, types): + """Legalize take op. + Parameters + ---------- + attrs : tvm.ir.Attrs + Attributes of current op + inputs : list of tvm.relay.Expr + The args of the Relay expr to be legalized + types : list of types + List of input and output types + Returns + ------- + result : tvm.relay.Expr + The legalized expr + """ + return topi.take_legalize(attrs, inputs, types) + + @script def _argwhere_shape_func_1d(condition): out = output_tensor((2,), "int64") diff --git a/python/tvm/topi/transform.py b/python/tvm/topi/transform.py index cdf9ce5c9275..6ddbc73e4666 100644 --- a/python/tvm/topi/transform.py +++ b/python/tvm/topi/transform.py @@ -426,6 +426,28 @@ def take(a, indices, axis=None, mode="clip"): return cpp.take(a, indices, int(axis), mode) +@tvm.target.generic_func +def take_legalize(attrs, inputs, types): + """Legalizes dyn.topk op. + + Parameters + ---------- + attrs : tvm.ir.Attrs + Attributes of current op + inputs : list of tvm.relay.Expr + The args of the Relay expr to be legalized + types : list of types + List of input and output types + Returns + ------- + result : tvm.relay.Expr + The legalized expr + """ + if tvm.relay.ty.is_dynamic(types[0]): + return tvm.relay.take(tvm.relay.annotation.stop_fusion(inputs[0]), inputs[1], **attrs) + return None + + def gather(data, axis, indices): """Gather values along given axis from given indices. diff --git a/tests/python/relay/test_pass_fuse_ops.py b/tests/python/relay/test_pass_fuse_ops.py index a3146de55d5a..30ee29525daa 100644 --- a/tests/python/relay/test_pass_fuse_ops.py +++ b/tests/python/relay/test_pass_fuse_ops.py @@ -14,6 +14,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import numpy as np + import tvm from tvm import relay from tvm.relay import transform @@ -757,6 +759,31 @@ def create_diamond_func(inp): assert tvm.ir.structural_equal(fused, expected) +def test_fuse_dynamic_squeeze_slice_take(): + input_data = [ + np.random.random([1, 2, 4]).astype("float32"), + np.array([0]).astype("int64"), + ] + + x = relay.var("p0107", shape=(relay.Any(), relay.Any(), 4), dtype="float32") + take_val = relay.var("p166", shape=(relay.Any(),), dtype="int64") + + squeeze = relay.op.squeeze(x, axis=[0]) + strided_slice = relay.op.strided_slice( + squeeze, begin=[0, 0], end=[15130, 9223372036854775807], strides=[1, 1] + ) + take = relay.op.take(strided_slice, take_val, axis=0) + + mod = tvm.IRModule.from_expr(take) + ex = relay.create_executor("vm", mod=mod, ctx=tvm.cpu(), target="llvm") + + result = ex.evaluate()(*input_data) + + np_result = np.squeeze(input_data[0][:, input_data[1][0], :], axis=0) + + assert np.allclose(result.asnumpy(), np_result) + + if __name__ == "__main__": test_fuse_simple() test_conv2d_fuse()