Skip to content

Commit

Permalink
[microNPU] Fixed MergeConstant pass on striped networks
Browse files Browse the repository at this point in the history
  • Loading branch information
sergio-grovety committed Nov 7, 2022
1 parent 60e2c98 commit fd1f6d7
Show file tree
Hide file tree
Showing 3 changed files with 208 additions and 13 deletions.
5 changes: 1 addition & 4 deletions python/tvm/relay/backend/contrib/ethosu/tir/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,10 +91,7 @@ def lower_ethosu(sch, args, const_dict, name="main"):
mod, const_dict = ethosu_passes.EncodeConstants(const_dict)(mod)
mod = ethosu_passes.HoistAllocates()(mod)
mod = tvm.tir.transform.RemoveNoOp()(mod)
# MergeConstant pass currently does not support striped schedules.
# It requires further investigation.
if not util.is_striping_enabled():
mod, const_dict = ethosu_passes.MergeConstants(const_dict)(mod)
mod, const_dict = ethosu_passes.MergeConstants(const_dict)(mod)
mod = ethosu_passes.CopyComputeReordering()(mod)

# When striping is enabled and if storage_rewrite is not run
Expand Down
27 changes: 18 additions & 9 deletions src/tir/contrib/ethosu/passes.cc
Original file line number Diff line number Diff line change
Expand Up @@ -514,7 +514,7 @@ class MergeConstantsMutator : public StmtExprMutator {

// Make the new const dict
Array<Array<IntImm>> args_to_merge{GetArgsToMerge(main_func->buffer_map, main_func->params)};
Array<Array<IntImm>> buffers_to_merge{
Map<IntImm, Array<IntImm>> buffers_to_merge{
GetArgsToMergeWithoutArgsNotInConstDict(args_to_merge, const_dict)};
Map<IntImm, runtime::NDArray> new_const_dict{MakeNewConstDict(buffers_to_merge, const_dict)};

Expand Down Expand Up @@ -832,31 +832,41 @@ class MergeConstantsMutator : public StmtExprMutator {
return vector;
}

Array<Array<IntImm>> GetArgsToMergeWithoutArgsNotInConstDict(
Map<IntImm, Array<IntImm>> GetArgsToMergeWithoutArgsNotInConstDict(
const Array<Array<IntImm>>& args_to_merge, const Map<IntImm, runtime::NDArray>& const_dict) {
Array<Array<IntImm>> new_args_to_merge{};
Map<IntImm, Array<IntImm>> new_args_to_merge{};
bool first_arg_found = false;
int64_t new_arg_key = 0; // the updated key of the merged const_dict
for (Array<IntImm> args : args_to_merge) {
IntImm key{args[0]};
auto it = std::find_if(const_dict.begin(), const_dict.end(),
[&](std::pair<tvm::IntImm, runtime::NDArray> pair) {
return pair.first->value == key->value;
});
if (it != const_dict.end()) {
new_args_to_merge.push_back(args);
if (first_arg_found == false) {
first_arg_found = true;
new_arg_key = key->value;
}
new_args_to_merge.Set(IntImm(DataType::Int(64), new_arg_key), args);
}
if (first_arg_found) {
new_arg_key++;
}
}
return new_args_to_merge;
}

Map<IntImm, runtime::NDArray> MakeNewConstDict(const Array<Array<IntImm>>& args_to_merge,
Map<IntImm, runtime::NDArray> MakeNewConstDict(const Map<IntImm, Array<IntImm>>& args_to_merge,
Map<IntImm, runtime::NDArray> const_dict) {
Map<IntImm, runtime::NDArray> new_const_dict{};
if (args_to_merge.size() == 0) {
return new_const_dict;
}

int64_t key = args_to_merge[0][0]->value;
for (Array<IntImm> args : args_to_merge) {
for (auto const& elem : args_to_merge) {
IntImm key = elem.first;
Array<IntImm> args = elem.second;
int64_t size = 0;
for (IntImm arg : args) {
auto it = std::find_if(const_dict.begin(), const_dict.end(),
Expand All @@ -876,8 +886,7 @@ class MergeConstantsMutator : public StmtExprMutator {
arg_constant.CopyToBytes(static_cast<uint8_t*>(constant->data) + offset, nbytes);
offset += nbytes;
}
new_const_dict.Set(IntImm(DataType::Int(64), key), constant);
key += 1;
new_const_dict.Set(key, constant);
}
return new_const_dict;
}
Expand Down
189 changes: 189 additions & 0 deletions tests/python/contrib/test_ethosu/test_merge_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,195 @@ def main(placeholder: T.Buffer[(8192,), "int8"], buffer1: T.Buffer[(464,), "uint
check_const_dictionaries(const_dict, new_const_dict)


def test_arbitrary_argument_order():
# Checks MergeConstants when the constant buffers are not contiguous in the
# parameter list: InputModule.main takes (placeholder, buffer1, buffer2,
# ethosu_write, buffer3, buffer4), i.e. the output buffer sits between the two
# pairs of constants.
# fmt: off
@tvm.script.ir_module
class InputModule:
@T.prim_func
def main(placeholder: T.Buffer[(8192,), "int8"], buffer1: T.Buffer[(368,), "uint8"], buffer2: T.Buffer[(96,), "uint8"], ethosu_write: T.Buffer[(4096,), "int8"], buffer3: T.Buffer[(368,), "uint8"], buffer4: T.Buffer[(96,), "uint8"]) -> None:
# function attr dict
T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
# buffer definition
T.preflattened_buffer(placeholder, [1, 16, 16, 32], dtype="int8", data=placeholder.data)
T.preflattened_buffer(ethosu_write, [1, 16, 16, 8], dtype="int8", data=ethosu_write.data)
# body
p1_data = T.allocate([368], "uint8", "global")
p1 = T.buffer_decl([368], "uint8", data=p1_data)
p2_data = T.allocate([96], "uint8", "global")
p2 = T.buffer_decl([96], "uint8", data=p2_data)
p3_data = T.allocate([368], "uint8", "global")
p3 = T.buffer_decl([368], "uint8", data=p3_data)
p4_data = T.allocate([96], "uint8", "global")
p4 = T.buffer_decl([96], "uint8", data=p4_data)
# First operator: buffer1 -> p1 and buffer2 -> p2 are copied, then both are read by the conv2d.
T.evaluate(T.call_extern("ethosu_copy", buffer1[0], 368, p1[0], dtype="handle"))
T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 96, p2[0], dtype="handle"))
T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 192, p1[192], 176, 12, p2[0], 48, p2[48], 48, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
# Second operator: buffer3 -> p3 and buffer4 -> p4 are copied, then both are read by the conv2d.
T.evaluate(T.call_extern("ethosu_copy", buffer3[0], 368, p3[0], dtype="handle"))
T.evaluate(T.call_extern("ethosu_copy", buffer4[0], 96, p4[0], dtype="handle"))
T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[2048], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p3[0], 192, p3[192], 176, 12, p4[0], 48, p4[48], 48, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
__tvm_meta__ = None


# Expected result: each 368 + 96 byte pair becomes a single 464 byte buffer, with one
# copy per operator, and main takes (placeholder, buffer1, ethosu_write, buffer2).
@tvm.script.ir_module
class ReferenceModule:
@T.prim_func
def main(placeholder: T.Buffer[(8192,), "int8"], buffer1: T.Buffer[(464,), "uint8"], ethosu_write: T.Buffer[(4096,), "int8"], buffer2: T.Buffer[(464,), "uint8"]) -> None:
# function attr dict
T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
# body
p1_data = T.allocate([464], "uint8", "global")
p1 = T.buffer_decl([464], "uint8", data=p1_data)
p2_data = T.allocate([464], "uint8", "global")
p2 = T.buffer_decl([464], "uint8", data=p2_data)
T.evaluate(T.call_extern("ethosu_copy", buffer1[0], 464, p1[0], dtype="handle"))
T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 192, p1[192], 176, 12, p1[368], 48, p1[416], 48, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 464, p2[0], dtype="handle"))
T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[2048], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p2[0], 192, p2[192], 176, 12, p2[368], 48, p2[416], 48, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
__tvm_meta__ = None
# fmt: on

# Keys 1, 2, 4, 5 line up with the positions of buffer1, buffer2, buffer3 and
# buffer4 in InputModule.main's parameter list (position 3 is ethosu_write).
const_dict = {
1: np.array([1], dtype=np.uint8),
2: np.array([2], dtype=np.uint8),
4: np.array([4], dtype=np.uint8),
5: np.array([5], dtype=np.uint8),
}
# Expected merged dictionary: keys 1 and 3 line up with the positions of buffer1
# and buffer2 in ReferenceModule.main's parameter list.
new_const_dict = {
1: np.concatenate((const_dict[1], const_dict[2])),
3: np.concatenate((const_dict[4], const_dict[5])),
}
# The pass returns the rewritten module together with the rewritten constant dictionary.
test_mod, const_dict = MergeConstants(const_dict)(InputModule)
reference_mod = ReferenceModule
# NOTE(review): the third positional argument is False here, while the two sibling
# test_arbitrary_argument_order_const_split* tests pass True - confirm this is intentional.
tvm.ir.assert_structural_equal(test_mod, reference_mod, False)
check_const_dictionaries(const_dict, new_const_dict)


def test_arbitrary_argument_order_const_split():
# Checks MergeConstants when the two constants of one operator are separated by a
# non-constant parameter: InputModule.main takes (placeholder, buffer1,
# ethosu_write, buffer2, buffer3, buffer4), so ethosu_write sits between buffer1
# and buffer2, which are both read by the first conv2d.
# fmt: off
@tvm.script.ir_module
class InputModule:
@T.prim_func
def main(placeholder: T.Buffer[(8192,), "int8"], buffer1: T.Buffer[(368,), "uint8"], ethosu_write: T.Buffer[(4096,), "int8"], buffer2: T.Buffer[(96,), "uint8"], buffer3: T.Buffer[(368,), "uint8"], buffer4: T.Buffer[(96,), "uint8"]) -> None:
# function attr dict
T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
# buffer definition
T.preflattened_buffer(placeholder, [1, 16, 16, 32], dtype="int8", data=placeholder.data)
T.preflattened_buffer(ethosu_write, [1, 16, 16, 8], dtype="int8", data=ethosu_write.data)
# body
p1_data = T.allocate([368], "uint8", "global")
p1 = T.buffer_decl([368], "uint8", data=p1_data)
p2_data = T.allocate([96], "uint8", "global")
p2 = T.buffer_decl([96], "uint8", data=p2_data)
p3_data = T.allocate([368], "uint8", "global")
p3 = T.buffer_decl([368], "uint8", data=p3_data)
p4_data = T.allocate([96], "uint8", "global")
p4 = T.buffer_decl([96], "uint8", data=p4_data)
# First operator: buffer1 -> p1 and buffer2 -> p2 are copied, then both are read by the conv2d.
T.evaluate(T.call_extern("ethosu_copy", buffer1[0], 368, p1[0], dtype="handle"))
T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 96, p2[0], dtype="handle"))
T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 192, p1[192], 176, 12, p2[0], 48, p2[48], 48, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
# Second operator: buffer3 -> p3 and buffer4 -> p4 are copied, then both are read by the conv2d.
T.evaluate(T.call_extern("ethosu_copy", buffer3[0], 368, p3[0], dtype="handle"))
T.evaluate(T.call_extern("ethosu_copy", buffer4[0], 96, p4[0], dtype="handle"))
T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[2048], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p3[0], 192, p3[192], 176, 12, p4[0], 48, p4[48], 48, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
__tvm_meta__ = None


# Expected result: each 368 + 96 byte pair becomes a single 464 byte buffer, with one
# copy per operator, and main takes (placeholder, buffer1, ethosu_write, buffer2).
@tvm.script.ir_module
class ReferenceModule:
@T.prim_func
def main(placeholder: T.Buffer[(8192,), "int8"], buffer1: T.Buffer[(464,), "uint8"], ethosu_write: T.Buffer[(4096,), "int8"], buffer2: T.Buffer[(464,), "uint8"]) -> None:
# function attr dict
T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
# body
p1_data = T.allocate([464], "uint8", "global")
p1 = T.buffer_decl([464], "uint8", data=p1_data)
p2_data = T.allocate([464], "uint8", "global")
p2 = T.buffer_decl([464], "uint8", data=p2_data)
T.evaluate(T.call_extern("ethosu_copy", buffer1[0], 464, p1[0], dtype="handle"))
T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 192, p1[192], 176, 12, p1[368], 48, p1[416], 48, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 464, p2[0], dtype="handle"))
T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[2048], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p2[0], 192, p2[192], 176, 12, p2[368], 48, p2[416], 48, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
__tvm_meta__ = None
# fmt: on

# Keys 1, 3, 4, 5 line up with the positions of buffer1, buffer2, buffer3 and
# buffer4 in InputModule.main's parameter list (position 2 is ethosu_write).
const_dict = {
1: np.array([1], dtype=np.uint8),
3: np.array([3], dtype=np.uint8),
4: np.array([4], dtype=np.uint8),
5: np.array([5], dtype=np.uint8),
}
# Expected merged dictionary: keys 1 and 3 line up with the positions of buffer1
# and buffer2 in ReferenceModule.main's parameter list.
new_const_dict = {
1: np.concatenate((const_dict[1], const_dict[3])),
3: np.concatenate((const_dict[4], const_dict[5])),
}
# The pass returns the rewritten module together with the rewritten constant dictionary.
test_mod, const_dict = MergeConstants(const_dict)(InputModule)
reference_mod = ReferenceModule
tvm.ir.assert_structural_equal(test_mod, reference_mod, True)
check_const_dictionaries(const_dict, new_const_dict)


def test_arbitrary_argument_order_const_split_mixed():
# Checks MergeConstants when the constants of the two operators are interleaved in
# the parameter list: InputModule.main takes (placeholder, buffer1, buffer2,
# ethosu_write, buffer3, buffer4), the first conv2d reads buffer1 and buffer3 and
# the second reads buffer2 and buffer4.
# fmt: off
@tvm.script.ir_module
class InputModule:
@T.prim_func
def main(placeholder: T.Buffer[(8192,), "int8"], buffer1: T.Buffer[(368,), "uint8"], buffer2: T.Buffer[(368,), "uint8"], ethosu_write: T.Buffer[(4096,), "int8"], buffer3: T.Buffer[(96,), "uint8"], buffer4: T.Buffer[(96,), "uint8"]) -> None:
# function attr dict
T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
# buffer definition
T.preflattened_buffer(placeholder, [1, 16, 16, 32], dtype="int8", data=placeholder.data)
T.preflattened_buffer(ethosu_write, [1, 16, 16, 8], dtype="int8", data=ethosu_write.data)
# body
p1_data = T.allocate([368], "uint8", "global")
p1 = T.buffer_decl([368], "uint8", data=p1_data)
p2_data = T.allocate([368], "uint8", "global")
p2 = T.buffer_decl([368], "uint8", data=p2_data)
p3_data = T.allocate([96], "uint8", "global")
p3 = T.buffer_decl([96], "uint8", data=p3_data)
p4_data = T.allocate([96], "uint8", "global")
p4 = T.buffer_decl([96], "uint8", data=p4_data)
# First operator: buffer1 -> p1 and buffer3 -> p3 are copied, then both are read by the conv2d.
T.evaluate(T.call_extern("ethosu_copy", buffer1[0], 368, p1[0], dtype="handle"))
T.evaluate(T.call_extern("ethosu_copy", buffer3[0], 96, p3[0], dtype="handle"))
T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 192, p1[192], 176, 12, p3[0], 48, p3[48], 48, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
# Second operator: buffer2 -> p2 and buffer4 -> p4 are copied, then both are read by the conv2d.
T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 368, p2[0], dtype="handle"))
T.evaluate(T.call_extern("ethosu_copy", buffer4[0], 96, p4[0], dtype="handle"))
T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[2048], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p2[0], 192, p2[192], 176, 12, p4[0], 48, p4[48], 48, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
__tvm_meta__ = None


# Expected result: each 368 + 96 byte pair becomes a single 464 byte buffer, with one
# copy per operator, and main takes (placeholder, buffer1, buffer2, ethosu_write).
@tvm.script.ir_module
class ReferenceModule:
@T.prim_func
def main(placeholder: T.Buffer[(8192,), "int8"], buffer1: T.Buffer[(464,), "uint8"], buffer2: T.Buffer[(464,), "uint8"], ethosu_write: T.Buffer[(4096,), "int8"]) -> None:
# function attr dict
T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
# body
p1_data = T.allocate([464], "uint8", "global")
p1 = T.buffer_decl([464], "uint8", data=p1_data)
p2_data = T.allocate([464], "uint8", "global")
p2 = T.buffer_decl([464], "uint8", data=p2_data)
T.evaluate(T.call_extern("ethosu_copy", buffer1[0], 464, p1[0], dtype="handle"))
T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 192, p1[192], 176, 12, p1[368], 48, p1[416], 48, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 464, p2[0], dtype="handle"))
T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[2048], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p2[0], 192, p2[192], 176, 12, p2[368], 48, p2[416], 48, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
__tvm_meta__ = None
# fmt: on

# Keys 1, 2, 4, 5 line up with the positions of buffer1, buffer2, buffer3 and
# buffer4 in InputModule.main's parameter list (position 3 is ethosu_write).
const_dict = {
1: np.array([1], dtype=np.uint8),
2: np.array([2], dtype=np.uint8),
4: np.array([4], dtype=np.uint8),
5: np.array([5], dtype=np.uint8),
}
# Expected merged dictionary: constants are paired per operator (1 with 4, 2 with 5);
# keys 1 and 2 line up with the positions of buffer1 and buffer2 in
# ReferenceModule.main's parameter list.
new_const_dict = {
1: np.concatenate((const_dict[1], const_dict[4])),
2: np.concatenate((const_dict[2], const_dict[5])),
}
# The pass returns the rewritten module together with the rewritten constant dictionary.
test_mod, const_dict = MergeConstants(const_dict)(InputModule)
reference_mod = ReferenceModule
tvm.ir.assert_structural_equal(test_mod, reference_mod, True)
check_const_dictionaries(const_dict, new_const_dict)


def test_cycle_count():
# fmt: off
@tvm.script.ir_module
Expand Down

0 comments on commit fd1f6d7

Please sign in to comment.