From fcf4c83ae8018833e208f38a778b5dcbf51bfa41 Mon Sep 17 00:00:00 2001 From: Matthew Brookhart Date: Wed, 11 Nov 2020 09:43:41 -0700 Subject: [PATCH 1/4] add a regression test for fusing dynamic take --- tests/python/relay/test_pass_fuse_ops.py | 27 ++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/tests/python/relay/test_pass_fuse_ops.py b/tests/python/relay/test_pass_fuse_ops.py index a3146de55d5a..30ee29525daa 100644 --- a/tests/python/relay/test_pass_fuse_ops.py +++ b/tests/python/relay/test_pass_fuse_ops.py @@ -14,6 +14,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import numpy as np + import tvm from tvm import relay from tvm.relay import transform @@ -757,6 +759,31 @@ def create_diamond_func(inp): assert tvm.ir.structural_equal(fused, expected) +def test_fuse_dynamic_squeeze_slice_take(): + input_data = [ + np.random.random([1, 2, 4]).astype("float32"), + np.array([0]).astype("int64"), + ] + + x = relay.var("p0107", shape=(relay.Any(), relay.Any(), 4), dtype="float32") + take_val = relay.var("p166", shape=(relay.Any(),), dtype="int64") + + squeeze = relay.op.squeeze(x, axis=[0]) + strided_slice = relay.op.strided_slice( + squeeze, begin=[0, 0], end=[15130, 9223372036854775807], strides=[1, 1] + ) + take = relay.op.take(strided_slice, take_val, axis=0) + + mod = tvm.IRModule.from_expr(take) + ex = relay.create_executor("vm", mod=mod, ctx=tvm.cpu(), target="llvm") + + result = ex.evaluate()(*input_data) + + np_result = np.squeeze(input_data[0][:, input_data[1][0], :], axis=0) + + assert np.allclose(result.asnumpy(), np_result) + + if __name__ == "__main__": test_fuse_simple() test_conv2d_fuse() From f6fc702f4dbf507ef23838f3fa6d3812a49dbeff Mon Sep 17 00:00:00 2001 From: Matthew Brookhart Date: Wed, 25 Nov 2020 15:59:59 -0700 Subject: [PATCH 2/4] add legalize for take that stops fusion on dynamic inputs --- python/tvm/relay/op/_transform.py | 19 +++++++++++++++++++ python/tvm/topi/transform.py | 22 ++++++++++++++++++++++ 2 files changed, 41 insertions(+) diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index e42b8bbae814..7843f840824c 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -332,6 +332,25 @@ def take_shape_func(attrs, inputs, out_ndims): return [_take_with_axis_shape_func(*inputs, convert(axis), out_ndims[0])] +@_reg.register_legalize("take") +def legalize_dyn_topk(attrs, inputs, types): + """Legalize take op. + Parameters + ---------- + attrs : tvm.ir.Attrs + Attributes of current convolution + inputs : list of tvm.relay.Expr + The args of the Relay expr to be legalized + types : list of types + List of input and output types + Returns + ------- + result : tvm.relay.Expr + The legalized expr + """ + return topi.take_legalize(attrs, inputs, types) + + @script def _argwhere_shape_func_1d(condition): out = output_tensor((2,), "int64") diff --git a/python/tvm/topi/transform.py b/python/tvm/topi/transform.py index cdf9ce5c9275..353b587da70c 100644 --- a/python/tvm/topi/transform.py +++ b/python/tvm/topi/transform.py @@ -426,6 +426,28 @@ def take(a, indices, axis=None, mode="clip"): return cpp.take(a, indices, int(axis), mode) +@tvm.target.generic_func +def take_legalize(attrs, inputs, types): + """Legalizes dyn.topk op. + + Parameters + ---------- + attrs : tvm.ir.Attrs + Attributes of current convolution + inputs : list of tvm.relay.Expr + The args of the Relay expr to be legalized + types : list of types + List of input and output types + Returns + ------- + result : tvm.relay.Expr + The legalized expr + """ + if tvm.relay.ty.is_dynamic(inputs[0].checked_type): + return tvm.relay.take(tvm.relay.annotation.stop_fusion(inputs[0]), inputs[1], **attrs) + return None + + def gather(data, axis, indices): """Gather values along given axis from given indices. From 7a740702f1f9709e9e709970a2ce7deb979e9b45 Mon Sep 17 00:00:00 2001 From: Matthew Brookhart Date: Wed, 25 Nov 2020 16:08:17 -0700 Subject: [PATCH 3/4] fix lint --- python/tvm/topi/transform.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/topi/transform.py b/python/tvm/topi/transform.py index 353b587da70c..9f0b4a2cba98 100644 --- a/python/tvm/topi/transform.py +++ b/python/tvm/topi/transform.py @@ -443,7 +443,7 @@ def take_legalize(attrs, inputs, types): result : tvm.relay.Expr The legalized expr """ - if tvm.relay.ty.is_dynamic(inputs[0].checked_type): + if tvm.relay.ty.is_dynamic(types[0]): return tvm.relay.take(tvm.relay.annotation.stop_fusion(inputs[0]), inputs[1], **attrs) return None From 45a4ccd59386936687222c19e84f474e9d57a6e7 Mon Sep 17 00:00:00 2001 From: Matthew Brookhart Date: Wed, 25 Nov 2020 17:54:48 -0700 Subject: [PATCH 4/4] fix typo --- python/tvm/relay/op/_transform.py | 2 +- python/tvm/topi/transform.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index 7843f840824c..439d44b5790b 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -338,7 +338,7 @@ def legalize_dyn_topk(attrs, inputs, types): Parameters ---------- attrs : tvm.ir.Attrs - Attributes of current convolution + Attributes of current op inputs : list of tvm.relay.Expr The args of the Relay expr to be legalized types : list of types diff --git a/python/tvm/topi/transform.py b/python/tvm/topi/transform.py index 9f0b4a2cba98..6ddbc73e4666 100644 --- a/python/tvm/topi/transform.py +++ b/python/tvm/topi/transform.py @@ -433,7 +433,7 @@ def take_legalize(attrs, inputs, types): Parameters ---------- attrs : tvm.ir.Attrs - Attributes of current convolution + Attributes of current op inputs : list of tvm.relay.Expr The args of the Relay expr to be legalized types : list of types