[ETHOSU] Add early simplify to fix LoopPartition (apache#9387)
* [ETHOSU] Add early simplify to fix LoopPartition

Certain loops aren't correctly partitioned if the loop
condition hasn't been simplified. This can happen when
a copy loop is split by a non-factor, i.e. a split factor
that doesn't evenly divide the loop extent. To fix this,
an additional Simplify pass is added to the TIR pipeline
immediately prior to LoopPartition.

Change-Id: Icd4ff14648ccaed41384da50c6d183a122b30048

* Fix linting again

Change-Id: I9c9dc2ee2c679861866b23531e88584b94198e51
mbaret authored Nov 1, 2021
1 parent 1f8ef2a commit e807743
Showing 2 changed files with 64 additions and 1 deletion.
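
To make the failure mode concrete, here is a minimal sketch (not part of the commit; it assumes the legacy TE schedule API and standard TIR passes of a TVM build from this era). Splitting an extent-10 copy loop by 3 leaves a min(3, 10 - i.outer*3) style bound on the inner loop, and LoopPartition only peels the partial final iteration cleanly once Simplify has canonicalized that condition:

import tvm
from tvm import te

# An extent-10 copy loop split by a non-factor: 3 does not divide 10.
a = te.placeholder((10,), dtype="int8", name="a")
b = te.compute((10,), lambda i: a[i], name="b")
sch = te.create_schedule(b.op)
xo, xi = sch[b].split(b.op.axis[0], factor=3)

# simple_mode stops short of LoopPartition so the passes can be applied by hand.
mod = tvm.lower(sch, [a, b], simple_mode=True)
mod = tvm.tir.transform.Simplify()(mod)  # the extra pass this commit adds
with tvm.transform.PassContext(config={"tir.LoopPartition": {"partition_const_loop": True}}):
    mod = tvm.tir.transform.LoopPartition()(mod)
print(mod["main"])
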
1 change: 1 addition & 0 deletions python/tvm/relay/backend/contrib/ethosu/tir/compiler.py
@@ -78,6 +78,7 @@ def lower_ethosu(sch, args, const_dict, name="main"):
     mod = tvm.tir.transform.Simplify()(mod)
     mod = tvm.tir.transform.StorageFlatten(64)(mod)
     mod = tvm.tir.transform.UnrollLoop()(mod)
+    mod = tvm.tir.transform.Simplify()(mod)
     mod = tvm.tir.transform.LoopPartition()(mod)
     mod = RemoveZeroStores()(mod)
     mod = tvm.tir.transform.Simplify()(mod)
64 changes: 63 additions & 1 deletion tests/python/contrib/test_ethosu/test_replace_copy.py
@@ -22,7 +22,7 @@
 from tvm import relay
 from tvm.relay.testing import run_opt_pass
 from tvm.relay.backend.contrib.ethosu.tir.compiler import lower_to_tir
-from tvm.relay.backend.contrib.ethosu.tir.scheduler import copy_constants
+from tvm.relay.backend.contrib.ethosu.tir.scheduler import copy_constants, Convolution2DCompute

 from .infra import make_ethosu_conv2d

@@ -73,5 +73,67 @@ def _get_func():
    tvm.ir.assert_structural_equal(test_mod["main"], reference_mod["main"], True)


# fmt: off
@tvm.script.ir_module
class WeightStream:
    @T.prim_func
    def main(placeholder: T.handle, ethosu_write: T.handle, placeholder_1: T.handle, placeholder_2: T.handle, placeholder_3: T.handle, placeholder_4: T.handle) -> None:
        # function attr dict
        T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
        placeholder_5 = T.match_buffer(placeholder, [1, 16, 16, 32], dtype="int8")
        ethosu_write_1 = T.match_buffer(ethosu_write, [1, 16, 16, 16], dtype="int8")
        buffer = T.match_buffer(placeholder_1, [416], dtype="uint8")
        buffer_1 = T.match_buffer(placeholder_2, [112], dtype="uint8")
        buffer_2 = T.match_buffer(placeholder_3, [272], dtype="uint8")
        buffer_3 = T.match_buffer(placeholder_4, [64], dtype="uint8")
        # body
        placeholder_global = T.allocate([416], "uint8", "global")
        placeholder_d_global = T.allocate([112], "uint8", "global")
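        # stripe 1: copy the first 416-byte weight chunk and 112-byte scale/bias chunk, then compute the first 10 of the 16 output channels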
T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer.data, 0), 416, T.load("uint8", placeholder_global, 0), dtype="handle"))
T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_1.data, 0), 112, T.load("uint8", placeholder_d_global, 0), dtype="handle"))
T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder_5.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 10, 16, 0, 16, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 416, 12, T.load("uint8", placeholder_d_global, 0), 112, 0, 0, 0, 0, "NONE", 0, 0, "NONE", dtype="handle"))
T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_2.data, 0), 272, T.load("uint8", placeholder_global, 0), dtype="handle"))
T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_3.data, 0), 64, T.load("uint8", placeholder_d_global, 0), dtype="handle"))
T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder_5.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 6, 16, 0, 16, T.load("int8", ethosu_write_1.data, 10), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 272, 12, T.load("uint8", placeholder_d_global, 0), 64, 0, 0, 0, 0, "NONE", 0, 0, "NONE", dtype="handle"))
__tvm_meta__ = None
# fmt: on


def test_weight_stream():
    def _cascader(cached_func, const_dict, sch):
        weight = cached_func.inputs[1]
        scale_bias = cached_func.inputs[2]
        out = cached_func.outputs[0]
        conv_compute = Convolution2DCompute.from_output(out)
        # split the output-channel axis (extent 16) by 10, a non-factor,
        # exercising the case the early Simplify pass fixes
        co = conv_compute.split(sch, 3, 10)
        cache_weight = sch.cache_read(weight, "global", [conv_compute.conv2d])
        cache_scale_bias = sch.cache_read(scale_bias, "global", [conv_compute.conv2d])
        sch[cache_weight].compute_at(sch[out], co)
        sch[cache_scale_bias].compute_at(sch[out], co)

    def _get_func():
        ifm = relay.var("ifm", shape=(1, 16, 16, 32), dtype="int8")
        conv = make_ethosu_conv2d(
            ifm,
            32,
            16,
            (1, 1),
            (0, 0),
            (1, 1),
            (1, 1),
        )
        func = relay.Function(relay.analysis.free_vars(conv), conv)
        func = run_opt_pass(func, relay.transform.InferType())
        return func

    func = _get_func()
    mod, _ = lower_to_tir(func, cascader=_cascader)

    script = mod.script(show_meta=True)
    test_mod = tvm.script.from_source(script)
    reference_mod = WeightStream
    tvm.ir.assert_structural_equal(test_mod["main"], reference_mod["main"], True)


if __name__ == "__main__":
    pytest.main([__file__])
