Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[microNPU] Fixed MergeConstants pass on striped networks #13281

Merged
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions python/tvm/relay/backend/contrib/ethosu/tir/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,10 +91,7 @@ def lower_ethosu(sch, args, const_dict, name="main"):
mod, const_dict = ethosu_passes.EncodeConstants(const_dict)(mod)
mod = ethosu_passes.HoistAllocates()(mod)
mod = tvm.tir.transform.RemoveNoOp()(mod)
# MergeConstant pass currently does not support striped schedules.
# It requires further investigation.
if not util.is_striping_enabled():
mod, const_dict = ethosu_passes.MergeConstants(const_dict)(mod)
mod, const_dict = ethosu_passes.MergeConstants(const_dict)(mod)
mod = ethosu_passes.CopyComputeReordering()(mod)

# When striping is enabled and if storage_rewrite is not run
Expand Down
27 changes: 18 additions & 9 deletions src/tir/contrib/ethosu/passes.cc
Original file line number Diff line number Diff line change
Expand Up @@ -514,7 +514,7 @@ class MergeConstantsMutator : public StmtExprMutator {

// Make the new const dict
Array<Array<IntImm>> args_to_merge{GetArgsToMerge(main_func->buffer_map, main_func->params)};
Array<Array<IntImm>> buffers_to_merge{
Map<IntImm, Array<IntImm>> buffers_to_merge{
GetArgsToMergeWithoutArgsNotInConstDict(args_to_merge, const_dict)};
Map<IntImm, runtime::NDArray> new_const_dict{MakeNewConstDict(buffers_to_merge, const_dict)};

Expand Down Expand Up @@ -832,31 +832,41 @@ class MergeConstantsMutator : public StmtExprMutator {
return vector;
}

Array<Array<IntImm>> GetArgsToMergeWithoutArgsNotInConstDict(
Map<IntImm, Array<IntImm>> GetArgsToMergeWithoutArgsNotInConstDict(
const Array<Array<IntImm>>& args_to_merge, const Map<IntImm, runtime::NDArray>& const_dict) {
Array<Array<IntImm>> new_args_to_merge{};
Map<IntImm, Array<IntImm>> new_args_to_merge{};
bool first_arg_found = false;
int64_t new_arg_key = 0; // the updated key of the merged const_dict
for (Array<IntImm> args : args_to_merge) {
IntImm key{args[0]};
auto it = std::find_if(const_dict.begin(), const_dict.end(),
[&](std::pair<tvm::IntImm, runtime::NDArray> pair) {
return pair.first->value == key->value;
});
if (it != const_dict.end()) {
new_args_to_merge.push_back(args);
if (first_arg_found == false) {
first_arg_found = true;
new_arg_key = key->value;
}
new_args_to_merge.Set(IntImm(DataType::Int(64), new_arg_key), args);
}
if (first_arg_found) {
new_arg_key++;
}
}
return new_args_to_merge;
}

Map<IntImm, runtime::NDArray> MakeNewConstDict(const Array<Array<IntImm>>& args_to_merge,
Map<IntImm, runtime::NDArray> MakeNewConstDict(const Map<IntImm, Array<IntImm>>& args_to_merge,
Map<IntImm, runtime::NDArray> const_dict) {
Map<IntImm, runtime::NDArray> new_const_dict{};
if (args_to_merge.size() == 0) {
return new_const_dict;
}

int64_t key = args_to_merge[0][0]->value;
for (Array<IntImm> args : args_to_merge) {
for (auto const& elem : args_to_merge) {
IntImm key = elem.first;
Array<IntImm> args = elem.second;
int64_t size = 0;
for (IntImm arg : args) {
auto it = std::find_if(const_dict.begin(), const_dict.end(),
Expand All @@ -876,8 +886,7 @@ class MergeConstantsMutator : public StmtExprMutator {
arg_constant.CopyToBytes(static_cast<uint8_t*>(constant->data) + offset, nbytes);
offset += nbytes;
}
new_const_dict.Set(IntImm(DataType::Int(64), key), constant);
key += 1;
new_const_dict.Set(key, constant);
}
return new_const_dict;
}
Expand Down
189 changes: 189 additions & 0 deletions tests/python/contrib/test_ethosu/test_merge_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,195 @@ def main(placeholder: T.Buffer[(8192,), "int8"], buffer1: T.Buffer[(464,), "uint
check_const_dictionaries(const_dict, new_const_dict)


def test_arbitrary_argument_order():
# Checks MergeConstants on a function whose constant parameters are not contiguous in the
# parameter list: in InputModule.main the output buffer ethosu_write sits between the
# constants used by the first conv2d (buffer1, buffer2) and those used by the second
# (buffer3, buffer4).
# fmt: off
# Input: every ethosu_conv2d is preceded by two ethosu_copy calls, one of 368 bytes and
# one of 96 bytes, each into its own allocation (p1..p4).
@tvm.script.ir_module
class InputModule:
@T.prim_func
def main(placeholder: T.Buffer[(8192,), "int8"], buffer1: T.Buffer[(368,), "uint8"], buffer2: T.Buffer[(96,), "uint8"], ethosu_write: T.Buffer[(4096,), "int8"], buffer3: T.Buffer[(368,), "uint8"], buffer4: T.Buffer[(96,), "uint8"]) -> None:
# function attr dict
T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
# buffer definition
T.preflattened_buffer(placeholder, [1, 16, 16, 32], dtype="int8", data=placeholder.data)
T.preflattened_buffer(ethosu_write, [1, 16, 16, 8], dtype="int8", data=ethosu_write.data)
# body
p1_data = T.allocate([368], "uint8", "global")
p1 = T.buffer_decl([368], "uint8", data=p1_data)
p2_data = T.allocate([96], "uint8", "global")
p2 = T.buffer_decl([96], "uint8", data=p2_data)
p3_data = T.allocate([368], "uint8", "global")
p3 = T.buffer_decl([368], "uint8", data=p3_data)
p4_data = T.allocate([96], "uint8", "global")
p4 = T.buffer_decl([96], "uint8", data=p4_data)
# first conv2d reads p1 and p2, writes ethosu_write[0]
T.evaluate(T.call_extern("ethosu_copy", buffer1[0], 368, p1[0], dtype="handle"))
T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 96, p2[0], dtype="handle"))
T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 192, p1[192], 176, 12, p2[0], 48, p2[48], 48, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
# second conv2d reads p3 and p4, writes ethosu_write[2048]
T.evaluate(T.call_extern("ethosu_copy", buffer3[0], 368, p3[0], dtype="handle"))
T.evaluate(T.call_extern("ethosu_copy", buffer4[0], 96, p4[0], dtype="handle"))
T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[2048], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p3[0], 192, p3[192], 176, 12, p4[0], 48, p4[48], 48, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
__tvm_meta__ = None


# Expected output: one 464-byte (368 + 96) copy per conv2d. The conv2d arguments that
# addressed the 96-byte buffer at offsets 0 and 48 now address the merged buffer at
# offsets 368 and 416. ethosu_write keeps its place between the two merged constants.
@tvm.script.ir_module
class ReferenceModule:
@T.prim_func
def main(placeholder: T.Buffer[(8192,), "int8"], buffer1: T.Buffer[(464,), "uint8"], ethosu_write: T.Buffer[(4096,), "int8"], buffer2: T.Buffer[(464,), "uint8"]) -> None:
# function attr dict
T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
# body
p1_data = T.allocate([464], "uint8", "global")
p1 = T.buffer_decl([464], "uint8", data=p1_data)
p2_data = T.allocate([464], "uint8", "global")
p2 = T.buffer_decl([464], "uint8", data=p2_data)
T.evaluate(T.call_extern("ethosu_copy", buffer1[0], 464, p1[0], dtype="handle"))
T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 192, p1[192], 176, 12, p1[368], 48, p1[416], 48, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 464, p2[0], dtype="handle"))
T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[2048], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p2[0], 192, p2[192], 176, 12, p2[368], 48, p2[416], 48, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
__tvm_meta__ = None
# fmt: on

# Keys 1, 2, 4, 5 line up with the positions of buffer1, buffer2, buffer3, buffer4 in
# InputModule.main (position 3 is ethosu_write). One-element arrays are used as
# stand-in values so each constant is identifiable after concatenation.
const_dict = {
1: np.array([1], dtype=np.uint8),
2: np.array([2], dtype=np.uint8),
4: np.array([4], dtype=np.uint8),
5: np.array([5], dtype=np.uint8),
}
# Keys 1 and 3 line up with the positions of buffer1 and buffer2 in ReferenceModule.main;
# each value is the concatenation of the two constants copied for one conv2d.
new_const_dict = {
1: np.concatenate((const_dict[1], const_dict[2])),
3: np.concatenate((const_dict[4], const_dict[5])),
}
# const_dict is rebound to the dictionary returned by the pass; new_const_dict above was
# built from the original values before this line.
test_mod, const_dict = MergeConstants(const_dict)(InputModule)
reference_mod = ReferenceModule
# NOTE(review): the third positional argument is False here while the two
# test_arbitrary_argument_order_const_split* tests pass True - confirm this is intended.
tvm.ir.assert_structural_equal(test_mod, reference_mod, False)
check_const_dictionaries(const_dict, new_const_dict)


def test_arbitrary_argument_order_const_split():
# Checks MergeConstants when the two constants that feed a single conv2d are themselves
# separated in the parameter list: in InputModule.main the output buffer ethosu_write
# sits between buffer1 and buffer2, which are both copied for the first conv2d.
# fmt: off
# Input: every ethosu_conv2d is preceded by two ethosu_copy calls, one of 368 bytes and
# one of 96 bytes, each into its own allocation (p1..p4).
@tvm.script.ir_module
class InputModule:
@T.prim_func
def main(placeholder: T.Buffer[(8192,), "int8"], buffer1: T.Buffer[(368,), "uint8"], ethosu_write: T.Buffer[(4096,), "int8"], buffer2: T.Buffer[(96,), "uint8"], buffer3: T.Buffer[(368,), "uint8"], buffer4: T.Buffer[(96,), "uint8"]) -> None:
# function attr dict
T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
# buffer definition
T.preflattened_buffer(placeholder, [1, 16, 16, 32], dtype="int8", data=placeholder.data)
T.preflattened_buffer(ethosu_write, [1, 16, 16, 8], dtype="int8", data=ethosu_write.data)
# body
p1_data = T.allocate([368], "uint8", "global")
p1 = T.buffer_decl([368], "uint8", data=p1_data)
p2_data = T.allocate([96], "uint8", "global")
p2 = T.buffer_decl([96], "uint8", data=p2_data)
p3_data = T.allocate([368], "uint8", "global")
p3 = T.buffer_decl([368], "uint8", data=p3_data)
p4_data = T.allocate([96], "uint8", "global")
p4 = T.buffer_decl([96], "uint8", data=p4_data)
# first conv2d reads p1 and p2, writes ethosu_write[0]
T.evaluate(T.call_extern("ethosu_copy", buffer1[0], 368, p1[0], dtype="handle"))
T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 96, p2[0], dtype="handle"))
T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 192, p1[192], 176, 12, p2[0], 48, p2[48], 48, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
# second conv2d reads p3 and p4, writes ethosu_write[2048]
T.evaluate(T.call_extern("ethosu_copy", buffer3[0], 368, p3[0], dtype="handle"))
T.evaluate(T.call_extern("ethosu_copy", buffer4[0], 96, p4[0], dtype="handle"))
T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[2048], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p3[0], 192, p3[192], 176, 12, p4[0], 48, p4[48], 48, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
__tvm_meta__ = None


# Expected output: one 464-byte (368 + 96) copy per conv2d. The conv2d arguments that
# addressed the 96-byte buffer at offsets 0 and 48 now address the merged buffer at
# offsets 368 and 416. The parameter order is placeholder, buffer1, ethosu_write, buffer2.
@tvm.script.ir_module
class ReferenceModule:
@T.prim_func
def main(placeholder: T.Buffer[(8192,), "int8"], buffer1: T.Buffer[(464,), "uint8"], ethosu_write: T.Buffer[(4096,), "int8"], buffer2: T.Buffer[(464,), "uint8"]) -> None:
# function attr dict
T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
# body
p1_data = T.allocate([464], "uint8", "global")
p1 = T.buffer_decl([464], "uint8", data=p1_data)
p2_data = T.allocate([464], "uint8", "global")
p2 = T.buffer_decl([464], "uint8", data=p2_data)
T.evaluate(T.call_extern("ethosu_copy", buffer1[0], 464, p1[0], dtype="handle"))
T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 192, p1[192], 176, 12, p1[368], 48, p1[416], 48, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 464, p2[0], dtype="handle"))
T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[2048], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p2[0], 192, p2[192], 176, 12, p2[368], 48, p2[416], 48, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
__tvm_meta__ = None
# fmt: on

# Keys 1, 3, 4, 5 line up with the positions of buffer1, buffer2, buffer3, buffer4 in
# InputModule.main (position 2 is ethosu_write). One-element arrays are used as
# stand-in values so each constant is identifiable after concatenation.
const_dict = {
1: np.array([1], dtype=np.uint8),
3: np.array([3], dtype=np.uint8),
4: np.array([4], dtype=np.uint8),
5: np.array([5], dtype=np.uint8),
}
# Keys 1 and 3 line up with the positions of buffer1 and buffer2 in ReferenceModule.main;
# each value is the concatenation of the two constants copied for one conv2d.
new_const_dict = {
1: np.concatenate((const_dict[1], const_dict[3])),
3: np.concatenate((const_dict[4], const_dict[5])),
}
# const_dict is rebound to the dictionary returned by the pass; new_const_dict above was
# built from the original values before this line.
test_mod, const_dict = MergeConstants(const_dict)(InputModule)
reference_mod = ReferenceModule
tvm.ir.assert_structural_equal(test_mod, reference_mod, True)
check_const_dictionaries(const_dict, new_const_dict)


def test_arbitrary_argument_order_const_split_mixed():
# Checks MergeConstants when the constants of the two conv2d calls are interleaved in the
# parameter list: the two 368-byte constants (buffer1, buffer2) come before ethosu_write
# and the two 96-byte constants (buffer3, buffer4) come after it, while each conv2d uses
# one constant from either side.
# fmt: off
# Input: the first conv2d is fed by copies of buffer1 and buffer3, the second by copies
# of buffer2 and buffer4.
@tvm.script.ir_module
class InputModule:
@T.prim_func
def main(placeholder: T.Buffer[(8192,), "int8"], buffer1: T.Buffer[(368,), "uint8"], buffer2: T.Buffer[(368,), "uint8"], ethosu_write: T.Buffer[(4096,), "int8"], buffer3: T.Buffer[(96,), "uint8"], buffer4: T.Buffer[(96,), "uint8"]) -> None:
# function attr dict
T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
# buffer definition
T.preflattened_buffer(placeholder, [1, 16, 16, 32], dtype="int8", data=placeholder.data)
T.preflattened_buffer(ethosu_write, [1, 16, 16, 8], dtype="int8", data=ethosu_write.data)
# body
p1_data = T.allocate([368], "uint8", "global")
p1 = T.buffer_decl([368], "uint8", data=p1_data)
p2_data = T.allocate([368], "uint8", "global")
p2 = T.buffer_decl([368], "uint8", data=p2_data)
p3_data = T.allocate([96], "uint8", "global")
p3 = T.buffer_decl([96], "uint8", data=p3_data)
p4_data = T.allocate([96], "uint8", "global")
p4 = T.buffer_decl([96], "uint8", data=p4_data)
# first conv2d reads p1 and p3, writes ethosu_write[0]
T.evaluate(T.call_extern("ethosu_copy", buffer1[0], 368, p1[0], dtype="handle"))
T.evaluate(T.call_extern("ethosu_copy", buffer3[0], 96, p3[0], dtype="handle"))
T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 192, p1[192], 176, 12, p3[0], 48, p3[48], 48, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
# second conv2d reads p2 and p4, writes ethosu_write[2048]
T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 368, p2[0], dtype="handle"))
T.evaluate(T.call_extern("ethosu_copy", buffer4[0], 96, p4[0], dtype="handle"))
T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[2048], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p2[0], 192, p2[192], 176, 12, p4[0], 48, p4[48], 48, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
__tvm_meta__ = None


# Expected output: one 464-byte (368 + 96) copy per conv2d. The conv2d arguments that
# addressed the 96-byte buffer at offsets 0 and 48 now address the merged buffer at
# offsets 368 and 416. Both merged constants precede ethosu_write in the parameter list.
@tvm.script.ir_module
class ReferenceModule:
@T.prim_func
def main(placeholder: T.Buffer[(8192,), "int8"], buffer1: T.Buffer[(464,), "uint8"], buffer2: T.Buffer[(464,), "uint8"], ethosu_write: T.Buffer[(4096,), "int8"]) -> None:
# function attr dict
T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
# body
p1_data = T.allocate([464], "uint8", "global")
p1 = T.buffer_decl([464], "uint8", data=p1_data)
p2_data = T.allocate([464], "uint8", "global")
p2 = T.buffer_decl([464], "uint8", data=p2_data)
T.evaluate(T.call_extern("ethosu_copy", buffer1[0], 464, p1[0], dtype="handle"))
T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 192, p1[192], 176, 12, p1[368], 48, p1[416], 48, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 464, p2[0], dtype="handle"))
T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[2048], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p2[0], 192, p2[192], 176, 12, p2[368], 48, p2[416], 48, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
__tvm_meta__ = None
# fmt: on

# Keys 1, 2, 4, 5 line up with the positions of buffer1, buffer2, buffer3, buffer4 in
# InputModule.main (position 3 is ethosu_write). One-element arrays are used as
# stand-in values so each constant is identifiable after concatenation.
const_dict = {
1: np.array([1], dtype=np.uint8),
2: np.array([2], dtype=np.uint8),
4: np.array([4], dtype=np.uint8),
5: np.array([5], dtype=np.uint8),
}
# Keys 1 and 2 line up with the positions of buffer1 and buffer2 in ReferenceModule.main.
# The pairs follow the copies issued per conv2d (buffer1 + buffer3, buffer2 + buffer4),
# not the parameter order.
new_const_dict = {
1: np.concatenate((const_dict[1], const_dict[4])),
2: np.concatenate((const_dict[2], const_dict[5])),
}
# const_dict is rebound to the dictionary returned by the pass; new_const_dict above was
# built from the original values before this line.
test_mod, const_dict = MergeConstants(const_dict)(InputModule)
reference_mod = ReferenceModule
tvm.ir.assert_structural_equal(test_mod, reference_mod, True)
check_const_dictionaries(const_dict, new_const_dict)


def test_cycle_count():
# fmt: off
@tvm.script.ir_module
Expand Down