Skip to content

Commit

Permalink
[microNPU] Add MergeConstants pass (#12029)
Browse files Browse the repository at this point in the history
* [microNPU] Add MergeConstants pass

Change-Id: I1ff51d8147fba8c66d442a370b9f058e9b2758d8

* Fix errors and warnings

Change-Id: I29f68f83a73fa00ca34ed0ab2321c53c6b761137

* Address comments

Change-Id: Iad59107d5abdec6b079c6fd4ab48c6bffbb5e0bb

* Fix lint error

Change-Id: Ie5caf506337de01e169d6f422e4682eefbd93241
  • Loading branch information
Nicola Lancellotti authored Jul 12, 2022
1 parent fc419df commit fbf80bb
Show file tree
Hide file tree
Showing 11 changed files with 1,336 additions and 263 deletions.
4 changes: 4 additions & 0 deletions python/tvm/relay/backend/contrib/ethosu/tir/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,10 @@ def lower_ethosu(sch, args, const_dict, name="main"):
mod = tvm.tir.transform.RemoveNoOp()(mod)
mod, const_dict = ethosu_passes.EncodeConstants(const_dict)(mod)
mod = ethosu_passes.HoistAllocates()(mod)
# MergeConstant pass currently does not support striped schedules.
# It requires further investigation.
if not util.is_striping_enabled():
mod, const_dict = ethosu_passes.MergeConstants(const_dict)(mod)
mod = ethosu_passes.CopyComputeReordering()(mod)

# When striping is enabled and if storage_rewrite is not run
Expand Down
35 changes: 35 additions & 0 deletions python/tvm/relay/backend/contrib/ethosu/tir/passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -938,3 +938,38 @@ def CopyComputeReordering(max_copy_movements: Optional[int] = None) -> tvm.IRMod
The new module with copy and compute nodes reordered.
"""
return _ffi_api.CopyComputeReordering(max_copy_movements)


def MergeConstants(const_dict):
"""
This pass looks for the constants used by each compute operator
and merges them into a single buffer.
Constants written to a buffer with local scope are not merged.
"""

def _merge_constants(mod):
nonlocal const_dict
try:
mod["main"]
except:
raise tvm.TVMError(
"Expected a single primitive function called 'main'. "
"Please run the MergeConstants pass in conjunction with the LowerToTIR() pass."
)

new_const_dict = {}
for param in const_dict.keys():
new_const_dict[tvm.tir.IntImm("int64", param)] = tvm.nd.array(const_dict[param])
mod["main"] = mod["main"].with_attr("ethos-u.const_dict", new_const_dict)

mod = _ffi_api.MergeConstants()(mod)
const_dict = mod["main"].attrs["ethos-u.const_dict"]
mod = _ffi_api.RemoveConstDictAttribute()(mod)

new_const_dict = {}
for param in const_dict.keys():
new_const_dict[int(param)] = const_dict[param].numpy()

return mod, new_const_dict

return _merge_constants
Loading

0 comments on commit fbf80bb

Please sign in to comment.