diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index 2eec6d03e7cd..95fb2ad18a25 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -23,6 +23,7 @@ schedule_injective = _reg.schedule_injective schedule_broadcast = _reg.schedule_injective +schedule_concatenate = _reg.schedule_concatenate _reg.register_schedule("collapse_sum_like", _schedule_reduce) @@ -46,7 +47,7 @@ _reg.register_schedule("transpose", schedule_injective) _reg.register_schedule("where", schedule_broadcast) _reg.register_schedule("stack", schedule_injective) -_reg.register_schedule("concatenate", schedule_injective) +_reg.register_schedule("concatenate", schedule_concatenate) _reg.register_schedule("_contrib_reverse_reshape", schedule_injective) _reg.register_schedule("gather_nd", schedule_injective) diff --git a/python/tvm/relay/op/op.py b/python/tvm/relay/op/op.py index 6ba207934d1b..906bf255d46e 100644 --- a/python/tvm/relay/op/op.py +++ b/python/tvm/relay/op/op.py @@ -219,6 +219,13 @@ def schedule_injective(attrs, outputs, target): with target: return topi.generic.schedule_injective(outputs) + +def schedule_concatenate(attrs, outputs, target): + """Generic schedule for concatinate.""" + with target: + return topi.generic.schedule_concatenate(outputs) + + __DEBUG_COUNTER__ = 0 def debug(expr, debug_func=None): diff --git a/topi/python/topi/arm_cpu/injective.py b/topi/python/topi/arm_cpu/injective.py index 9afdc32cf117..028558f69e91 100644 --- a/topi/python/topi/arm_cpu/injective.py +++ b/topi/python/topi/arm_cpu/injective.py @@ -51,3 +51,32 @@ def schedule_injective(outs): elif len(s[x].op.axis) >= 2: s[x].parallel(s[x].op.axis[0]) return s + +@generic.schedule_concatenate.register(["arm_cpu"]) +def schedule_concatenate(outs): + """Schedule for concatenate op. + + Parameters + ---------- + outs: Array of Tensor + The computation graph description of reduce in the format + of an array of tensors. + + Returns + ------- + sch: Schedule + The computation schedule for the op. + """ + outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs + s = tvm.create_schedule([x.op for x in outs]) + x = outs[0] + tvm.schedule.AutoInlineInjective(s) + if len(s[x].op.axis) >= 4: + fused = s[x].fuse(s[x].op.axis[0], s[x].op.axis[1], s[x].op.axis[2]) + s[x].parallel(fused) + elif len(s[x].op.axis) >= 3: + fused = s[x].fuse(s[x].op.axis[0], s[x].op.axis[1]) + s[x].parallel(fused) + elif len(s[x].op.axis) >= 2: + s[x].parallel(s[x].op.axis[0]) + return s diff --git a/topi/tests/python/test_topi_transform.py b/topi/tests/python/test_topi_transform.py index a078eacae85b..d29fb64544b9 100644 --- a/topi/tests/python/test_topi_transform.py +++ b/topi/tests/python/test_topi_transform.py @@ -127,7 +127,7 @@ def check_device(device): return print("Running on target: %s" % device) with tvm.target.create(device): - s = topi.generic.schedule_injective(out_tensor) + s = topi.generic.schedule_concatenate(out_tensor) foo = tvm.build(s, tensor_l + [out_tensor], device, name="concatenate") data_npys = [np.random.normal(size=shape).astype(tensor_l[0].dtype) for shape in shapes] @@ -476,6 +476,7 @@ def test_concatenate(): (12, 6, 7, 3), (8, 6, 7, 3), (2, 6, 7, 3)], 0) + verify_concatenate([(1, 14400), (1, 2400), (1, 640), (1, 240)], 1) def test_stack():