From 3cd5b39da3895a56d37c97f768a7043fdddac839 Mon Sep 17 00:00:00 2001 From: Matthew Barrett Date: Tue, 31 Aug 2021 11:41:42 +0100 Subject: [PATCH 1/2] [ETHOSU] Add early simplify to fix LoopPartition Certain loops aren't correctly partitioned if the loop condition hasn't been simplified. This can happen when a copy loop is split by a non-factor. To fix this, an additional simplify pass is added to the TIR pipeline prior to LoopPartition. Change-Id: Icd4ff14648ccaed41384da50c6d183a122b30048 --- .../backend/contrib/ethosu/tir/compiler.py | 1 + .../contrib/test_ethosu/test_replace_copy.py | 62 ++++++++++++++++++- 2 files changed, 62 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py b/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py index c792ade06643..bc95a9a3bab7 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py @@ -78,6 +78,7 @@ def lower_ethosu(sch, args, const_dict, name="main"): mod = tvm.tir.transform.Simplify()(mod) mod = tvm.tir.transform.StorageFlatten(64)(mod) mod = tvm.tir.transform.UnrollLoop()(mod) + mod = tvm.tir.transform.Simplify()(mod) mod = tvm.tir.transform.LoopPartition()(mod) mod = RemoveZeroStores()(mod) mod = tvm.tir.transform.Simplify()(mod) diff --git a/tests/python/contrib/test_ethosu/test_replace_copy.py b/tests/python/contrib/test_ethosu/test_replace_copy.py index 76b7ef2a70ee..0ef933f730e9 100644 --- a/tests/python/contrib/test_ethosu/test_replace_copy.py +++ b/tests/python/contrib/test_ethosu/test_replace_copy.py @@ -22,7 +22,7 @@ from tvm import relay from tvm.relay.testing import run_opt_pass from tvm.relay.backend.contrib.ethosu.tir.compiler import lower_to_tir -from tvm.relay.backend.contrib.ethosu.tir.scheduler import copy_constants +from tvm.relay.backend.contrib.ethosu.tir.scheduler import copy_constants, Convolution2DCompute from .infra import make_ethosu_conv2d @@ -73,5 +73,65 @@ def _get_func(): tvm.ir.assert_structural_equal(test_mod["main"], reference_mod["main"], True) +@tvm.script.ir_module +class WeightStream: + @T.prim_func + def main(placeholder: T.handle, ethosu_write: T.handle, placeholder_1: T.handle, placeholder_2: T.handle, placeholder_3: T.handle, placeholder_4: T.handle) -> None: + # function attr dict + T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + placeholder_5 = T.match_buffer(placeholder, [1, 16, 16, 32], dtype="int8") + ethosu_write_1 = T.match_buffer(ethosu_write, [1, 16, 16, 16], dtype="int8") + buffer = T.match_buffer(placeholder_1, [416], dtype="uint8") + buffer_1 = T.match_buffer(placeholder_2, [112], dtype="uint8") + buffer_2 = T.match_buffer(placeholder_3, [272], dtype="uint8") + buffer_3 = T.match_buffer(placeholder_4, [64], dtype="uint8") + # body + placeholder_global = T.allocate([416], "uint8", "global") + placeholder_d_global = T.allocate([112], "uint8", "global") + T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer.data, 0), 416, T.load("uint8", placeholder_global, 0), dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_1.data, 0), 112, T.load("uint8", placeholder_d_global, 0), dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder_5.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 10, 16, 0, 16, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 416, 12, T.load("uint8", placeholder_d_global, 0), 112, 0, 0, 0, 0, "NONE", 0, 0, "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_2.data, 0), 272, T.load("uint8", placeholder_global, 0), dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_3.data, 0), 64, T.load("uint8", placeholder_d_global, 0), dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder_5.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 6, 16, 0, 16, T.load("int8", ethosu_write_1.data, 10), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 272, 12, T.load("uint8", placeholder_d_global, 0), 64, 0, 0, 0, 0, "NONE", 0, 0, "NONE", dtype="handle")) + __tvm_meta__ = None + + +def test_weight_stream(): + def _cascader(cached_func, const_dict, sch): + weight = cached_func.inputs[1] + scale_bias = cached_func.inputs[2] + out = cached_func.outputs[0] + conv_compute = Convolution2DCompute.from_output(out) + co = conv_compute.split(sch, 3, 10) + cache_weight = sch.cache_read(weight, "global", [conv_compute.conv2d]) + cache_scale_bias = sch.cache_read(scale_bias, "global", [conv_compute.conv2d]) + sch[cache_weight].compute_at(sch[out], co) + sch[cache_scale_bias].compute_at(sch[out], co) + + def _get_func(): + ifm = relay.var("ifm", shape=(1, 16, 16, 32), dtype="int8") + conv = make_ethosu_conv2d( + ifm, + 32, + 16, + (1, 1), + (0, 0), + (1, 1), + (1, 1), + ) + func = relay.Function(relay.analysis.free_vars(conv), conv) + func = run_opt_pass(func, relay.transform.InferType()) + return func + + func = _get_func() + mod, _ = lower_to_tir(func, cascader=_cascader) + + script = mod.script(show_meta=True) + test_mod = tvm.script.from_source(script) + reference_mod = WeightStream + tvm.ir.assert_structural_equal(test_mod["main"], reference_mod["main"], True) + + if __name__ == "__main__": pytest.main([__file__]) From 7ec99a7ea518cd28cffb7ab51af308fc7f6524fb Mon Sep 17 00:00:00 2001 From: Matthew Barrett Date: Fri, 29 Oct 2021 17:03:37 +0100 Subject: [PATCH 2/2] Fix linting again Change-Id: I9c9dc2ee2c679861866b23531e88584b94198e51 --- tests/python/contrib/test_ethosu/test_replace_copy.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/python/contrib/test_ethosu/test_replace_copy.py b/tests/python/contrib/test_ethosu/test_replace_copy.py index 0ef933f730e9..9590db57dd32 100644 --- a/tests/python/contrib/test_ethosu/test_replace_copy.py +++ b/tests/python/contrib/test_ethosu/test_replace_copy.py @@ -73,6 +73,7 @@ def _get_func(): tvm.ir.assert_structural_equal(test_mod["main"], reference_mod["main"], True) +# fmt: off @tvm.script.ir_module class WeightStream: @T.prim_func @@ -95,6 +96,7 @@ def main(placeholder: T.handle, ethosu_write: T.handle, placeholder_1: T.handle, T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_3.data, 0), 64, T.load("uint8", placeholder_d_global, 0), dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder_5.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 6, 16, 0, 16, T.load("int8", ethosu_write_1.data, 10), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 272, 12, T.load("uint8", placeholder_d_global, 0), 64, 0, 0, 0, 0, "NONE", 0, 0, "NONE", dtype="handle")) __tvm_meta__ = None +# fmt: on def test_weight_stream():