From cd8b94b38fb160cd0a407c343bbf84b4bc73fc83 Mon Sep 17 00:00:00 2001 From: Nicola Lancellotti Date: Thu, 9 Jun 2022 09:39:50 +0000 Subject: [PATCH 1/4] [microNPU] Add MergeConstants pass Change-Id: I1ff51d8147fba8c66d442a370b9f058e9b2758d8 --- .../backend/contrib/ethosu/tir/compiler.py | 2 + .../backend/contrib/ethosu/tir/passes.py | 35 ++ src/tir/contrib/ethosu/passes.cc | 567 ++++++++++++++++++ .../test_ethosu/cascader/test_integration.py | 10 +- .../test_ethosu/test_encode_constants.py | 244 +++----- .../test_ethosu/test_merge_constants.py | 561 +++++++++++++++++ .../contrib/test_ethosu/test_networks.py | 14 +- .../test_ethosu/test_remove_concatenates.py | 3 - .../test_ethosu/test_replace_conv2d.py | 24 - .../contrib/test_ethosu/test_replace_copy.py | 37 +- .../contrib/test_ethosu/test_scheduler.py | 24 +- 11 files changed, 1286 insertions(+), 235 deletions(-) create mode 100644 tests/python/contrib/test_ethosu/test_merge_constants.py diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py b/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py index 0fd82378c300..b4896bd85e44 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py @@ -90,6 +90,8 @@ def lower_ethosu(sch, args, const_dict, name="main"): mod = tvm.tir.transform.RemoveNoOp()(mod) mod, const_dict = ethosu_passes.EncodeConstants(const_dict)(mod) mod = ethosu_passes.HoistAllocates()(mod) + if not util.is_striping_enabled(): + mod, const_dict = ethosu_passes.MergeConstants(const_dict)(mod) mod = ethosu_passes.CopyComputeReordering()(mod) # When striping is enabled and if storage_rewrite is not run diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py index 76726132e05d..0efa23d36d7d 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py @@ -938,3 +938,38 @@ def CopyComputeReordering(max_copy_movements: Optional[int] = None) -> tvm.IRMod The new module with copy and compute nodes reordered. """ return _ffi_api.CopyComputeReordering(max_copy_movements) + + +def MergeConstants(const_dict): + """ + This pass looks for the constants used by each compute operator + and merges them into a single buffer. + Constants written to a buffer with local scope are not merged. + """ + + def mergeConstantsPass(mod): + nonlocal const_dict + try: + mod["main"] + except: + raise tvm.TVMError( + "Expected a single primitive function called 'main'. " + "Please run the MergeConstants pass in conjunction with the LowerToTIR() pass." + ) + + new_const_dict = {} + for param in const_dict.keys(): + new_const_dict[tvm.tir.IntImm("int64", param)] = tvm.nd.array(const_dict[param]) + mod["main"] = mod["main"].with_attr("ethos-u.const-dict", new_const_dict) + + mod = _ffi_api.MergeConstants()(mod) + const_dict = mod["main"].attrs["ethos-u.const-dict"] + mod = _ffi_api.RemoveConstDictAttribute()(mod) + + new_const_dict = {} + for param in const_dict.keys(): + new_const_dict[int(param)] = const_dict[param].numpy() + + return mod, new_const_dict + + return mergeConstantsPass diff --git a/src/tir/contrib/ethosu/passes.cc b/src/tir/contrib/ethosu/passes.cc index 09c359c55abb..d39a19c480c5 100644 --- a/src/tir/contrib/ethosu/passes.cc +++ b/src/tir/contrib/ethosu/passes.cc @@ -24,6 +24,7 @@ */ #include #include +#include #include #include @@ -223,6 +224,572 @@ tvm::transform::Pass CopyComputeReordering(Optional max_copy_movements) TVM_REGISTER_GLOBAL("tir.contrib.ethos-u.CopyComputeReordering") .set_body_typed(CopyComputeReordering); +/*! + * \brief This pass looks for the constants used by each compute operator + * and merges them into a single buffer. + * Constants written to a buffer with local scope are not merged. + */ +class MergeConstantsMutator : public StmtExprMutator { + public: + MergeConstantsMutator() {} + + PrimFunc operator()(PrimFunc main_func, const Map& const_dict) { + // Analyze + Stmt new_body{this->VisitStmt(main_func->body)}; + + // Rewrite + analyze = false; + new_body = rewrite_prim_func_body(new_body); + std::set params_to_delete{}; + auto new_buffer_map{make_new_buffer_map(main_func->buffer_map, ¶ms_to_delete)}; + auto new_params{make_new_params(main_func->params, params_to_delete)}; + + // Make the new const dict + auto args_to_merge{get_args_to_merge(main_func->buffer_map, main_func->params)}; + auto buffers_to_merge{ + get_args_to_merge_without_args_not_in_const_dict(args_to_merge, const_dict)}; + auto new_const_dict{make_new_const_dict(buffers_to_merge, const_dict)}; + + // Make the new prim func + auto prim_func_node{main_func.CopyOnWrite()}; + prim_func_node->body = std::move(new_body); + prim_func_node->buffer_map = std::move(new_buffer_map); + prim_func_node->params = std::move(new_params); + prim_func_node->preflattened_buffer_map = {}; + PrimFunc f{GetRef(prim_func_node)}; + + // Add the new const dict as an attribute + f = WithAttr(std::move(f), "ethos-u.const-dict", new_const_dict); + + return f; + } + + private: + /*! Indicates whether the pass is analyzing or rewriting */ + bool analyze = true; + + /*! A stack to store allocates as they are visited. */ + std::vector allocates{}; + + /*! A list that contains in the i-th position the write buffer of the i-th statement + * if that statement is a copy to a buffer with global scope */ + std::vector> copy_write_buffers{}; + + /*! Maps a copy's write buffer to an index representing the + * new buffer and an offset in that buffer */ + std::map> + old_to_new_write_buffer{}; + + /*! Maps an index representing a new buffer to the length of that buffer */ + std::map new_buffers_length{}; + + /*! Maps an index representing a new buffer to the new buffer */ + std::map new_buffers{}; + + /*! Maps an index representing a new buffer to the cycle_counts needed to copy that buffer */ + std::map cycle_counts{}; + + /*! Maps a copy's read buffer to the new copy's read buffer */ + std::map old_to_new_read_buffers{}; + + /*! Maps an index representing a new buffer to the list of buffers to be merged in the new buffer + */ + std::map> buffers_to_merge{}; + + /*! A set of buffers to delete */ + std::set buffers_to_delete{}; + + // Visit + + Stmt VisitStmt_(const AllocateNode* op) override { + if (analyze) { + allocates.push_back(GetRef(op)); + return VisitStmt(op->body); + } else { + auto allocate{CopyOnWrite(op)}; + allocate->body = this->VisitStmt(op->body); + return Stmt(allocate); + } + } + + Stmt VisitStmt_(const SeqStmtNode* op) override { + if (op->size() <= 1) { + return StmtExprMutator::VisitStmt_(op); + } + return analyze ? analyze_seq_stmt(op) : rewrite_seq_stmt(op); + } + + Stmt analyze_seq_stmt(const SeqStmtNode* op) { + auto seq_stmt{GetRef(op)}; + + for (size_t i = 0; i < seq_stmt.size(); ++i) { + Stmt stmt{seq_stmt[i]}; + + switch (get_stmt_type(stmt)) { + case StmtType::global_copy: { + Buffer write_buffer{get_copy_write_buffer(stmt)}; + copy_write_buffers.push_back(write_buffer); + old_to_new_write_buffer[write_buffer] = std::make_pair(-1, -1); + break; + } + case StmtType::local_copy: { + copy_write_buffers.push_back(Optional{}); + break; + } + case StmtType::compute: { + copy_write_buffers.push_back(Optional{}); + auto buffers{get_copied_buffers_used_by_stmt(stmt)}; + if (buffers.empty()) { + continue; + } + new_buffers_length[i] = 0; + for (auto buffer : buffers) { + for (size_t j{i - 1}; j >= 0; --j) { + if (copy_write_buffers[j] == buffer) { + old_to_new_write_buffer[buffer] = std::make_pair(i, new_buffers_length[i]); + new_buffers_length[i] += get_copy_length(seq_stmt[j]); + cycle_counts[i] += get_stmt_cycle_counts(seq_stmt[j]); + break; + } + } + } + break; + } + } + } + return seq_stmt; + } + + Stmt rewrite_prim_func_body(Stmt body) { + std::map var_to_allocate{}; + + // Rewrite old allocates + std::set buffer_vars{get_vars_for_written_copy_buffers()}; + for (auto it{allocates.rbegin()}; it != allocates.rend(); ++it) { + Allocate alloc{*it}; + var_to_allocate[alloc->buffer_var.get()] = alloc; + if (buffer_vars.count(alloc->buffer_var) == 0) { + body = Allocate(alloc->buffer_var, alloc->dtype, alloc->extents, alloc->condition, body, + alloc->annotations, alloc->span); + } + } + + // Rewrite new allocates + for (auto it{copy_write_buffers.rbegin()}; it != copy_write_buffers.rend(); ++it) { + if (auto buffer_opt = *it) { + Buffer old_write_buffer{buffer_opt.value()}; + int new_buffer_index{old_to_new_write_buffer[old_write_buffer].first}; + + // Check if the allocate has already been created + if (new_buffers.count(new_buffer_index) == 0) { + BufferNode* new_buffer{old_write_buffer.CopyOnWrite()}; + new_buffer->shape = {new_buffers_length[new_buffer_index]}; + + new_buffers[new_buffer_index] = GetRef(new_buffer); + + auto old_allocate{var_to_allocate[old_write_buffer->data.get()]}; + body = Allocate(new_buffer->data, new_buffer->dtype, new_buffer->shape, tir::const_true(), + body, old_allocate->annotations, old_allocate->span); + } + } + } + + // Rewrite operators + return this->VisitStmt(body); + } + + Stmt rewrite_seq_stmt(const SeqStmtNode* op) { + Array new_seq{}; + + auto seq_stmt{GetRef(op)}; + for (size_t i{0}; i < seq_stmt.size(); ++i) { + Stmt stmt{seq_stmt[i]}; + + switch (get_stmt_type(stmt)) { + case StmtType::global_copy: { + Buffer old_write_buffer{copy_write_buffers[i].value()}; + auto pair{old_to_new_write_buffer[old_write_buffer]}; + auto new_buffer_index{pair.first}; + auto new_buffer_offset{pair.second}; + update_buffers_to_merge_and_delete(stmt, new_buffer_index, new_buffer_offset); + + if (!is_copy_to_be_deleted(new_buffer_offset)) { + auto cycle_counts{get_merged_cycle_counts(new_buffer_index)}; + new_seq.push_back(make_new_stmt( + stmt, make_new_copy_args(stmt, old_write_buffer, new_buffer_index), cycle_counts)); + } + break; + } + case StmtType::local_copy: { + new_seq.push_back(stmt); + break; + } + case StmtType::compute: { + new_seq.push_back(make_new_stmt(stmt, make_new_compute_args(stmt))); + break; + } + } + } + return SeqStmt(new_seq, op->span); + } + + enum class StmtType { global_copy, local_copy, compute }; + + StmtType get_stmt_type(const Stmt& stmt) { + auto args{get_stmt_args(stmt)}; + if (args[0].as()->value == "ethosu_copy") { + if (args[3].as()->buffer.scope() == "global") { + return StmtType::global_copy; + } else { + return StmtType::local_copy; + } + } + return StmtType::compute; + } + + Buffer get_copy_read_buffer(const Stmt& stmt) { + auto args{get_stmt_args(stmt)}; + return args[1].as()->buffer; + } + + Buffer get_copy_write_buffer(const Stmt& stmt) { + auto args{get_stmt_args(stmt)}; + return args[3].as()->buffer; + } + + int64_t get_copy_length(const Stmt& stmt) { + auto args{get_stmt_args(stmt)}; + return args[2].as()->value; + } + + int64_t get_stmt_cycle_counts(const Stmt& stmt) { + auto attr{stmt.as()}; + if (attr && attr->attr_key == "pragma_compute_cycles_hint") { + int64_t cycle_count = Downcast(attr->value); + return cycle_count; + } + return 0; + } + + std::vector get_copied_buffers_used_by_stmt(const Stmt& stmt) { + std::vector buffers{}; + for (auto arg : get_stmt_args(stmt)) { + if (auto buffer_load = arg.as()) { + auto buffer{buffer_load->buffer}; + // Check if the buffer has already been added + if (std::find(buffers.begin(), buffers.end(), buffer) == buffers.end()) { + // Check if the buffer is copied + if (old_to_new_write_buffer.count(buffer)) { + buffers.push_back(buffer); + } + } + } + } + return buffers; + } + + std::set get_vars_for_written_copy_buffers() { + std::set buffer_vars{}; + std::transform(old_to_new_write_buffer.begin(), old_to_new_write_buffer.end(), + std::inserter(buffer_vars, buffer_vars.begin()), + [](auto pair) -> Var { return pair.first->data; }); + return buffer_vars; + } + + tvm::runtime::Array get_stmt_args(const Stmt& stmt) { + auto attr{stmt.as()}; + Stmt eval_stmt{attr ? attr->body : stmt}; + auto eval{eval_stmt.as()}; + ICHECK(eval) << "Expected statement to be an evaluate node, but was " + << eval_stmt->GetTypeKey(); + auto call{eval->value.as()}; + ICHECK(call) << "Expected expression to be a call node, but was " << eval->value->GetTypeKey(); + return call->args; + } + + Optional get_merged_cycle_counts(int new_buffer_index) { + auto it = cycle_counts.find(new_buffer_index); + if (it != cycle_counts.end()) { + return Integer(it->second); + } + return Optional{}; + } + + bool is_copy_to_be_deleted(int new_buffer_offset) { return new_buffer_offset > 0; } + + Array make_new_copy_args(const Stmt& stmt, const Buffer& old_write_buffer, + int new_buffer_index) { + Array args{get_stmt_args(stmt)}; + auto new_length{new_buffers_length[new_buffer_index]}; + + Array new_args{}; + for (size_t i = 0; i < args.size(); ++i) { + switch (i) { + case 1: /* read_address */ { + auto buffer_load = args[1].as(); + auto buffer{buffer_load->buffer}; + Buffer new_buffer{buffer->data, + buffer->dtype, + {new_length}, + buffer->strides, + buffer->elem_offset, + buffer->name, + buffer->data_alignment, + buffer->offset_factor, + buffer->buffer_type, + buffer->axis_separators, + buffer->span}; + old_to_new_read_buffers[buffer] = new_buffer; + new_args.push_back(BufferLoad(new_buffer, buffer_load->indices, buffer_load->span)); + break; + } + case 2: /* length */ { + new_args.push_back(new_length); + break; + } + case 3: /* write_address */ { + new_args.push_back(make_new_buffer_load(old_write_buffer, 0, true).value()); + break; + } + default: + new_args.push_back(args[i]); + break; + } + } + return new_args; + } + + Array make_new_compute_args(const Stmt& stmt) { + Array args{get_stmt_args(stmt)}; + Array new_args{}; + for (size_t i = 0; i < args.size(); ++i) { + if (auto buffer_load = args[i].as()) { + auto new_buffer_load{ + make_new_buffer_load(buffer_load->buffer, buffer_load->indices[0], false) + .value_or(GetRef(buffer_load))}; + new_args.push_back(new_buffer_load); + } else { + new_args.push_back(args[i]); + } + } + return new_args; + } + + Stmt make_new_stmt(const Stmt& stmt, const Array& new_args, + Optional cycle_counts = Optional{}) { + auto attr{stmt.as()}; + Stmt eval_stmt{attr ? attr->body : stmt}; + auto eval{eval_stmt.as()}; + ICHECK(eval) << "Expected statement to be an evaluate node, but was " + << eval_stmt->GetTypeKey(); + auto call{eval->value.as()}; + ICHECK(call) << "Expected expression to be a call node, but was " << eval->value->GetTypeKey(); + + Call new_call{call->dtype, call->op, new_args, call->span}; + Evaluate new_eval{new_call, eval->span}; + + if (attr) { + ICHECK(attr->attr_key == "pragma_compute_cycles_hint"); + PrimExpr value = cycle_counts.value_or(attr->value); + return AttrStmt{attr->node, attr->attr_key, value, new_eval, attr->span}; + } else { + return new_eval; + } + } + + Optional make_new_buffer_load(const Buffer& write_buffer, const PrimExpr& old_index, + bool only_old_index) { + auto it = old_to_new_write_buffer.find(write_buffer); + if (it != old_to_new_write_buffer.end()) { + auto pair{it->second}; + auto new_buffer_index{pair.first}; + auto new_index{only_old_index ? old_index : (pair.second + old_index)}; + return BufferLoad{new_buffers[new_buffer_index], {new_index}}; + } + return Optional{}; + } + + Map make_new_buffer_map(const Map& buffer_map, + std::set* params_to_delete) { + Map new_buffer_map{}; + for (auto pair : buffer_map) { + Var var{pair.first}; + Buffer buffer{pair.second}; + + if (buffers_to_delete.count(buffer) == 1) { + params_to_delete->insert(var); + } else if (old_to_new_read_buffers.count(buffer) == 1) { + new_buffer_map.Set(var, old_to_new_read_buffers[buffer]); + } else { + new_buffer_map.Set(var, buffer); + } + } + return new_buffer_map; + } + + Array make_new_params(const Array& params, + const std::set& params_to_delete) { + std::vector new_params{}; + for (auto var : params) { + if (params_to_delete.count(var) == 0) { + new_params.push_back(var); + } + } + return new_params; + } + + void update_buffers_to_merge_and_delete(const Stmt& stmt, int new_buffer_index, + int new_buffer_offset) { + Array args{get_stmt_args(stmt)}; + Buffer read_buffer{get_copy_read_buffer(stmt)}; + + if (buffers_to_merge.count(new_buffer_index) == 0) { + buffers_to_merge[new_buffer_index] = std::vector{read_buffer}; + } else { + buffers_to_merge[new_buffer_index].push_back(read_buffer); + } + + if (new_buffer_offset > 0) { + buffers_to_delete.insert(read_buffer); + } + } + + /*! Returns an array whose elements are the indices of the function arguments to be merged. + * Example: if a function has three arguments and the second and the third ones must + * be merged then the array is: [[0], [1, 2], [3]] */ + Array> get_args_to_merge(const Map& buffer_map, + const Array& params) { + std::map buffer_to_var{}; + for (auto var_buffer : buffer_map) { + buffer_to_var[var_buffer.second] = var_buffer.first; + } + + std::map var_to_index{}; + for (int i = 0; i < static_cast(params.size()); ++i) { + var_to_index[params[i]] = i; + } + + std::vector> vector{}; + for (auto index_vector : buffers_to_merge) { + std::vector indices{}; + for (auto buffer : index_vector.second) { + auto var{buffer_to_var[buffer]}; + IntImm index{DataType::Int(64), var_to_index[var]}; + var_to_index.erase(var); + auto it = std::find_if(indices.begin(), indices.end(), + [&](IntImm value) { return value->value == index->value; }); + if (it == indices.end()) { + indices.push_back(index); + } + } + vector.push_back(Array{indices}); + } + + for (auto var_index : var_to_index) { + vector.push_back(Array{IntImm(DataType::Int(64), var_index.second)}); + } + std::sort(vector.begin(), vector.end(), + [](Array a, Array b) { return a[0]->value < b[0]->value; }); + return vector; + } + + Array> get_args_to_merge_without_args_not_in_const_dict( + const Array>& args_to_merge, const Map& const_dict) { + Array> new_args_to_merge{}; + for (auto args : args_to_merge) { + IntImm key{args[0]}; + auto it = std::find_if(const_dict.begin(), const_dict.end(), + [&](std::pair pair) { + return pair.first->value == key->value; + }); + if (it != const_dict.end()) { + new_args_to_merge.push_back(args); + } + } + return new_args_to_merge; + } + + Map make_new_const_dict(const Array>& args_to_merge, + Map const_dict) { + Map new_const_dict{}; + if (args_to_merge.size() == 0) { + return new_const_dict; + } + + int64_t key = args_to_merge[0][0]->value; + for (auto args : args_to_merge) { + int64_t size = 0; + for (auto arg : args) { + auto it = std::find_if(const_dict.begin(), const_dict.end(), + [&](auto pair) { return pair.first->value == arg->value; }); + auto arg_constant{(*it).second}; + size += runtime::GetDataSize(*arg_constant.operator->()); + } + + runtime::NDArray constant = runtime::NDArray::Empty({size}, DataType::UInt(8), {kDLCPU, 0}); + + size_t offset = 0; + for (auto arg : args) { + auto it = std::find_if(const_dict.begin(), const_dict.end(), + [&](auto pair) { return pair.first->value == arg->value; }); + auto arg_constant{(*it).second}; + size_t nbytes = runtime::GetDataSize(*arg_constant.operator->()); + arg_constant.CopyToBytes(static_cast(constant->data) + offset, nbytes); + offset += nbytes; + } + new_const_dict.Set(IntImm(DataType::Int(64), key), constant); + key += 1; + } + return new_const_dict; + } +}; + +/*! + * \brief This pass looks for the constants used by each compute operator + * and merges them into a single buffer. + * Constants written to a buffer with local scope are not merged. + * \return tvm::transform::Pass + */ +tvm::transform::Pass MergeConstants() { + auto pass_func = [=](PrimFunc f, IRModule mod, tvm::transform::PassContext ctx) { + ICHECK(mod->GetGlobalVars().size() == 1 && mod->ContainGlobalVar("main")) + << "Expected a single primitive function called 'main'. Please run the " + "MergeConstants pass in conjunction with the LowerToTIR() pass."; + auto const_dict{ + f->attrs.GetAttr("ethos-u.const-dict", Optional>{})}; + ICHECK(const_dict) << "Expected a ethos-u.const-dict attribute"; + return MergeConstantsMutator()(f, const_dict.value()); + }; + return tvm::tir::transform::CreatePrimFuncPass(pass_func, 0, "tir.contrib.ethos-u.MergeConstants", + {}); +} + +TVM_REGISTER_GLOBAL("tir.contrib.ethos-u.MergeConstants").set_body_typed(MergeConstants); + +/*! + * \brief This pass removes the ethos-u.const-dict attribute + * \return tvm::transform::Pass + */ +class RemoveConstDictAttributeMutator : public StmtExprMutator { + public: + RemoveConstDictAttributeMutator() {} + + PrimFunc operator()(PrimFunc main_func) { + return WithoutAttr(std::move(main_func), "ethos-u.const-dict"); + } +}; + +tvm::transform::Pass RemoveConstDictAttribute() { + auto pass_func = [=](PrimFunc f, IRModule mod, tvm::transform::PassContext ctx) { + return RemoveConstDictAttributeMutator()(f); + }; + return tvm::tir::transform::CreatePrimFuncPass( + pass_func, 0, "tir.contrib.ethos-u.RemoveConstDictAttribute", {}); +} + +TVM_REGISTER_GLOBAL("tir.contrib.ethos-u.RemoveConstDictAttribute") + .set_body_typed(RemoveConstDictAttribute); + } // namespace ethosu } // namespace contrib } // namespace tir diff --git a/tests/python/contrib/test_ethosu/cascader/test_integration.py b/tests/python/contrib/test_ethosu/cascader/test_integration.py index 8e1f020861d5..14cc8fbc61cf 100644 --- a/tests/python/contrib/test_ethosu/cascader/test_integration.py +++ b/tests/python/contrib/test_ethosu/cascader/test_integration.py @@ -109,9 +109,8 @@ def test_single_conv_compute_cycles_hint(): for single convolution. """ primfunc = _compile_model(_create_single_conv2d()) - ops = primfunc.body.body.body.seq - - compute_cycles_hints = [2304, 640, 320] + ops = primfunc.body.body.seq + compute_cycles_hints = [2944, 320] for op, compute_cycle_hint in zip(ops, compute_cycles_hints): assert op.attr_key == "pragma_compute_cycles_hint" assert op.value == compute_cycle_hint @@ -123,9 +122,8 @@ def test_double_conv_compute_cycles_hint(): for double convolution. """ primfunc = _compile_model(_create_double_conv2d()) - ops = primfunc.body.body.body.body.body.body.seq - - compute_cycles_hints = [2304, 640, 768, 640, 320, 240] + ops = primfunc.body.body.body.body.seq + compute_cycles_hints = [2944, 1408, 320, 240] for op, compute_cycle_hint in zip(ops, compute_cycles_hints): assert op.attr_key == "pragma_compute_cycles_hint" assert op.value == compute_cycle_hint diff --git a/tests/python/contrib/test_ethosu/test_encode_constants.py b/tests/python/contrib/test_ethosu/test_encode_constants.py index 15b719f33c3f..fd9f373739e1 100644 --- a/tests/python/contrib/test_ethosu/test_encode_constants.py +++ b/tests/python/contrib/test_ethosu/test_encode_constants.py @@ -37,34 +37,23 @@ class WeightStreamOnlyU55: def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - buffer1 = T.buffer_decl([128], "uint8") - buffer2 = T.buffer_decl([32], "uint8") - buffer3 = T.buffer_decl([112], "uint8") - buffer4 = T.buffer_decl([32], "uint8") - buffer5 = T.buffer_decl([112], "uint8") - buffer6 = T.buffer_decl([32], "uint8") - buffer7 = T.buffer_decl([112], "uint8") + buffer1 = T.buffer_decl([160], "uint8") + buffer3 = T.buffer_decl([144], "uint8") + buffer5 = T.buffer_decl([144], "uint8") + buffer7 = T.buffer_decl([144], "uint8") buffer8 = T.buffer_decl([32], "uint8") - T.preflattened_buffer(placeholder, [1, 16, 16, 32], "int8", data=placeholder.data) - T.preflattened_buffer(ethosu_write, [1, 16, 16, 8], "int8", data=ethosu_write.data) # body - p1 = T.allocate([128], "uint8", "global", annotations={"disable_lower_builtin":True}) - p2 = T.allocate([32], "uint8", "global", annotations={"disable_lower_builtin":True}) - p3 = T.allocate([112], "uint8", "global", annotations={"disable_lower_builtin":True}) - p4 = T.allocate([32], "uint8", "global", annotations={"disable_lower_builtin":True}) - buffer9 = T.buffer_decl([112], "uint8", data=p1.data) - T.evaluate(T.call_extern("ethosu_copy", buffer1[0], 128, p1[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 32, p2[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", buffer3[0], 112, p3[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", buffer4[0], 32, p4[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 128, T.int8(-1), T.int8(-1), 12, p2[0], 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", buffer5[0], 112, buffer9[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", buffer6[0], 32, p2[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[2], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p3[0], 112, T.int8(-1), T.int8(-1), 12, p4[0], 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", buffer7[0], 112, p3[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", buffer8[0], 32, p4[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[4], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, buffer9[0], 112, T.int8(-1), T.int8(-1), 12, p2[0], 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[6], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p3[0], 112, T.int8(-1), T.int8(-1), 12, p4[0], 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + p1 = T.allocate([160], "uint8", "global", annotations={"disable_lower_builtin":True}) + p2 = T.allocate([144], "uint8", "global", annotations={"disable_lower_builtin":True}) + buffer9 = T.buffer_decl([144], "uint8", data=p1.data) + T.evaluate(T.call_extern("ethosu_copy", buffer1[0], 160, p1[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer3[0], 144, p2[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 128, T.int8(-1), T.int8(-1), 12, p1[128], 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer5[0], 144, buffer9[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[2], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p2[0], 112, T.int8(-1), T.int8(-1), 12, p2[112], 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer7[0], 144, p2[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[4], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, buffer9[0], 112, T.int8(-1), T.int8(-1), 12, buffer9[112], 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[6], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p2[0], 112, T.int8(-1), T.int8(-1), 12, p2[112], 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) __tvm_meta__ = None @@ -75,34 +64,22 @@ def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) # buffer definition - T.preflattened_buffer(placeholder, [1, 16, 16, 32], dtype="int8", data=placeholder.data) - T.preflattened_buffer(ethosu_write, [1, 16, 16, 8], dtype="int8", data=ethosu_write.data) - buffer_encoded_1 = T.buffer_decl([160], dtype="uint8") - buffer_encoded_1_1 = T.buffer_decl([32], dtype="uint8") - buffer_encoded_2_1 = T.buffer_decl([160], dtype="uint8") - buffer_encoded_3_1 = T.buffer_decl([32], dtype="uint8") - buffer_encoded_4_1 = T.buffer_decl([176], dtype="uint8") - buffer_encoded_5_1 = T.buffer_decl([32], dtype="uint8") - buffer_encoded_6_1 = T.buffer_decl([160], dtype="uint8") - buffer_encoded_7_1 = T.buffer_decl([32], dtype="uint8") + buffer_encoded_1 = T.buffer_decl([192], dtype="uint8") + buffer_encoded_2_1 = T.buffer_decl([192], dtype="uint8") + buffer_encoded_4_1 = T.buffer_decl([208], dtype="uint8") + buffer_encoded_6_1 = T.buffer_decl([192], dtype="uint8") # body - placeholder_global = T.allocate([176], "uint8", "global", annotations={"disable_lower_builtin":True}) - placeholder_d_global = T.allocate([32], "uint8", "global", annotations={"disable_lower_builtin":True}) - placeholder_global_2 = T.allocate([160], "uint8", "global", annotations={"disable_lower_builtin":True}) - placeholder_d_global_2 = T.allocate([32], "uint8", "global", annotations={"disable_lower_builtin":True}) - placeholder_global_1 = T.buffer_decl([160], dtype="uint8", data=placeholder_global.data) - T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_1[0], 160, placeholder_global_1[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_1_1[0], 32, placeholder_d_global[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_2_1[0], 160, placeholder_global_2[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_3_1[0], 32, placeholder_d_global_2[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global_1[0], 80, placeholder_global_1[80], 80, 12, placeholder_d_global[0], 16, placeholder_d_global[16], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_4_1[0], 176, placeholder_global[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_5_1[0], 32, placeholder_d_global[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[2], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global_2[0], 80, placeholder_global_2[80], 80, 12, placeholder_d_global_2[0], 16, placeholder_d_global_2[16], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_6_1[0], 160, placeholder_global_2[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_7_1[0], 32, placeholder_d_global_2[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[4], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 96, placeholder_global[96], 80, 12, placeholder_d_global[0], 16, placeholder_d_global[16], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[6], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global_2[0], 80, placeholder_global_2[80], 80, 12, placeholder_d_global_2[0], 16, placeholder_d_global_2[16], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + p1 = T.allocate([208], "uint8", "global", annotations={"disable_lower_builtin":True}) + p2 = T.allocate([192], "uint8", "global", annotations={"disable_lower_builtin":True}) + p3 = T.buffer_decl([192], dtype="uint8", data=p1.data) + T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_1[0], 192, p3[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_2_1[0], 192, p2[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p3[0], 80, p3[80], 80, 12, p3[160], 16, p3[176], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_4_1[0], 208, p1[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[2], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p2[0], 80, p2[80], 80, 12, p2[160], 16, p2[176], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_6_1[0], 192, p2[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[4], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 96, p1[96], 80, 12, p1[176], 16, p1[192], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[6], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p2[0], 80, p2[80], 80, 12, p2[160], 16, p2[176], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) __tvm_meta__ = None # fmt: on @@ -113,12 +90,12 @@ def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), ( "ethos-u55-128", WeightStreamOnlyU55, - [128, 32, 112, 32, 112, 32, 112, 32], + [160, 144, 144, 144], ), ( "ethos-u65-512", WeightStreamOnlyU65, - [160, 32, 160, 32, 176, 32, 160, 32], + [192, 192, 208, 192], ), ], ) @@ -160,7 +137,7 @@ def _get_func(): tvm.ir.assert_structural_equal(test_mod["main"], reference_mod["main"], True) test_const_size = [value.size for value in list(consts.values())] - assert reference_const_sizes == test_const_size + assert reference_const_sizes.sort() == test_const_size.sort() # fmt: off @@ -170,21 +147,14 @@ class RereadWeightsU55: def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - buffer1 = T.buffer_decl([304], "uint8") - buffer2 = T.buffer_decl([80], "uint8") - T.preflattened_buffer(placeholder, [1, 16, 16, 32], "int8", data=placeholder.data) - T.preflattened_buffer(ethosu_write, [1, 16, 16, 8], "int8", data=ethosu_write.data) + buffer1 = T.buffer_decl([384], "uint8") # body - p1 = T.allocate([304], "uint8", "global", annotations={"disable_lower_builtin":True}) - p2 = T.allocate([80], "uint8", "global", annotations={"disable_lower_builtin":True}) - p3 = T.allocate([304], "uint8", "global", annotations={"disable_lower_builtin":True}) - p4 = T.allocate([80], "uint8", "global", annotations={"disable_lower_builtin":True}) - T.evaluate(T.call_extern("ethosu_copy", buffer1[0], 304, p1[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 80, p2[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", buffer1[0], 304, p3[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 80, p4[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 304, T.int8(-1), T.int8(-1), 12, p2[0], 80, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[256], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[64], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p3[0], 304, T.int8(-1), T.int8(-1), 12, p4[0], 80, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + p1 = T.allocate([384], "uint8", "global", annotations={"disable_lower_builtin":True}) + p2 = T.allocate([384], "uint8", "global", annotations={"disable_lower_builtin":True}) + T.evaluate(T.call_extern("ethosu_copy", buffer1[0], 384, p1[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer1[0], 384, p2[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 304, T.int8(-1), T.int8(-1), 12, p1[304], 80, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[256], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[64], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p2[0], 304, T.int8(-1), T.int8(-1), 12, p2[304], 80, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) __tvm_meta__ = None @@ -195,21 +165,14 @@ def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) # buffer definition - T.preflattened_buffer(placeholder, [1, 16, 16, 32], dtype="int8", data=placeholder.data) - T.preflattened_buffer(ethosu_write, [1, 16, 16, 8], dtype="int8", data=ethosu_write.data) - placeholder_encoded_1 = T.buffer_decl([368], "uint8") - placeholder_encoded_1_2 = T.buffer_decl([96], "uint8") + placeholder_encoded_1 = T.buffer_decl([464], "uint8") # body - placeholder_global = T.allocate([368], "uint8", "global", annotations={"disable_lower_builtin":True}) - placeholder_d_global = T.allocate([96], "uint8", "global", annotations={"disable_lower_builtin":True}) - placeholder_global_1 = T.allocate([368], "uint8", "global", annotations={"disable_lower_builtin":True}) - placeholder_d_global_1 = T.allocate([96], "uint8", "global", annotations={"disable_lower_builtin":True}) - T.evaluate(T.call_extern("ethosu_copy", placeholder_encoded_1[0], 368, placeholder_global[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", placeholder_encoded_1_2[0], 96, placeholder_d_global[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", placeholder_encoded_1[0], 368, placeholder_global_1[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", placeholder_encoded_1_2[0], 96, placeholder_d_global_1[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 192, placeholder_global[192], 176, 12, placeholder_d_global[0], 48, placeholder_d_global[48], 48, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[256], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[64], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global_1[0], 192, placeholder_global_1[192], 176, 12, placeholder_d_global_1[0], 48, placeholder_d_global_1[48], 48, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + p1 = T.allocate([464], "uint8", "global", annotations={"disable_lower_builtin":True}) + p2 = T.allocate([464], "uint8", "global", annotations={"disable_lower_builtin":True}) + T.evaluate(T.call_extern("ethosu_copy", placeholder_encoded_1[0], 464, p1[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", placeholder_encoded_1[0], 464, p2[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 192, p1[192], 176, 12, p1[368], 48, p1[416], 48, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[256], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[64], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p2[0], 192, p2[192], 176, 12, p2[368], 48, p2[416], 48, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) __tvm_meta__ = None # fmt: on @@ -221,12 +184,12 @@ def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), ( "ethos-u55-128", RereadWeightsU55, - [304, 80], + [384], ), ( "ethos-u65-512", RereadWeightsU65, - [368, 96], + [464], ), ], ) @@ -268,7 +231,7 @@ def _get_func(): tvm.ir.assert_structural_equal(test_mod["main"], reference_mod["main"], True) test_const_size = [value.size for value in list(consts.values())] - assert reference_const_sizes == test_const_size + assert reference_const_sizes.sort() == test_const_size.sort() # fmt: off @@ -282,8 +245,6 @@ def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), buffer_1 = T.buffer_decl([160], "uint8") buffer_2 = T.buffer_decl([160], "uint8") buffer_3 = T.buffer_decl([80], "uint8") - T.preflattened_buffer(placeholder, [1, 16, 16, 32], "int8", data=placeholder.data) - T.preflattened_buffer(ethosu_write, [1, 16, 16, 8], "int8", data=ethosu_write.data) # body ethosu_write_1 = T.allocate([4096], "int8", "global", annotations={"disable_lower_builtin":True}) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 16, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, buffer[0], 592, T.int8(-1), T.int8(-1), 12, buffer_1[0], 160, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) @@ -302,8 +263,6 @@ def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), placeholder_encoded_1 = T.buffer_decl([160], dtype="uint8") placeholder_encoded_2 = T.buffer_decl([208], dtype="uint8") placeholder_encoded_3 = T.buffer_decl([96], dtype="uint8") - T.preflattened_buffer(placeholder, [1, 16, 16, 32], dtype="int8", data=placeholder.data) - T.preflattened_buffer(ethosu_write, [1, 16, 16, 8], dtype="int8", data=ethosu_write.data) # body ethosu_write_2 = T.allocate([4096], "int8", "global", annotations={"disable_lower_builtin":True}) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 16, 16, 0, 16, ethosu_write_2[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, placeholder_encoded[0], 304, placeholder_encoded[304], 304, 12, placeholder_encoded_1[0], 80, placeholder_encoded_1[80], 80, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) @@ -364,87 +323,64 @@ def _get_func(): tvm.ir.assert_structural_equal(test_mod["main"], reference_mod["main"], True) test_const_size = [value.size for value in list(consts.values())] - assert reference_const_sizes == test_const_size + assert reference_const_sizes.sort() == test_const_size.sort() # fmt: off @tvm.script.ir_module class MixedReadU55: @T.prim_func - def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), "int8"]) -> None: + def main(placeholder: T.Buffer[(8192,), "int8"], buffer_encoded: T.Buffer[(112,), "uint8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - buffer1 = T.buffer_decl([80], "uint8") - buffer2 = T.buffer_decl([32], "uint8") - buffer3 = T.buffer_decl([80], "uint8") - buffer4 = T.buffer_decl([32], "uint8") - buffer5 = T.buffer_decl([80], "uint8") - buffer6 = T.buffer_decl([32], "uint8") - buffer7 = T.buffer_decl([80], "uint8") - buffer8 = T.buffer_decl([32], "uint8") + buffer1 = T.buffer_decl([112], "uint8") + buffer3 = T.buffer_decl([112], "uint8") + buffer5 = T.buffer_decl([112], "uint8") buffer9 = T.buffer_decl([592], "uint8") buffer10 = T.buffer_decl([160], "uint8") - T.preflattened_buffer(placeholder, [1, 16, 16, 32], "int8", data=placeholder.data) - T.preflattened_buffer(ethosu_write, [1, 16, 16, 8], "int8", data=ethosu_write.data) + buffer11 = T.buffer_decl([2048], "int8") # body - p1 = T.allocate([80], "uint8", "global", annotations={"disable_lower_builtin":True}) - p2 = T.allocate([32], "uint8", "global", annotations={"disable_lower_builtin":True}) + p1 = T.allocate([112], "uint8", "global", annotations={"disable_lower_builtin":True}) p3 = T.allocate([4096], "int8", "global", annotations={"disable_lower_builtin":True}) - p4 = T.allocate([80], "uint8", "global", annotations={"disable_lower_builtin":True}) - p5 = T.allocate([32], "uint8", "global", annotations={"disable_lower_builtin":True}) - T.evaluate(T.call_extern("ethosu_copy", buffer1[0], 80, p1[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 32, p2[0], dtype="handle")) + p2 = T.allocate([112], "uint8", "global", annotations={"disable_lower_builtin":True}) + T.evaluate(T.call_extern("ethosu_copy", buffer1[0], 112, p1[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 16, 16, 0, 16, p3[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, buffer9[0], 592, T.int8(-1), T.int8(-1), 12, buffer10[0], 160, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", buffer3[0], 80, p4[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", buffer4[0], 32, p5[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, p3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 80, T.int8(-1), T.int8(-1), 12, p2[0], 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", buffer5[0], 80, p1[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", buffer6[0], 32, p2[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, p3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[2], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p4[0], 80, T.int8(-1), T.int8(-1), 12, p5[0], 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", buffer7[0], 80, p4[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", buffer8[0], 32, p5[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, p3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[4], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 80, T.int8(-1), T.int8(-1), 12, p2[0], 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, p3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[6], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p4[0], 80, T.int8(-1), T.int8(-1), 12, p5[0], 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer3[0], 112, p2[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, p3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, buffer11[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 80, T.int8(-1), T.int8(-1), 12, p1[80], 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer5[0], 112, p1[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, p3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, buffer11[2], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p2[0], 80, T.int8(-1), T.int8(-1), 12, p2[80], 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_encoded[0], 112, p2[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, p3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, buffer11[4], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 80, T.int8(-1), T.int8(-1), 12, p1[80], 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, p3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, buffer11[6], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p2[0], 80, T.int8(-1), T.int8(-1), 12, p2[80], 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) __tvm_meta__ = None @tvm.script.ir_module class MixedReadU65: @T.prim_func - def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), "int8"]) -> None: + def main(placeholder: T.Buffer[(8192,), "int8"], buffer_encoded: T.Buffer[(128,), "uint8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - T.preflattened_buffer(ethosu_write, [1, 16, 16, 8], "int8", data=ethosu_write.data) - T.preflattened_buffer(placeholder, [1, 16, 16, 32], "int8", data=placeholder.data) + # buffer definition - buffer_encoded_1 = T.buffer_decl([96], dtype="uint8") - buffer_encoded_1_2 = T.buffer_decl([32], dtype="uint8") - placeholder_encoded_1 = T.buffer_decl([608], dtype="uint8") - placeholder_encoded_1_2 = T.buffer_decl([160], dtype="uint8") - buffer_encoded_2_1 = T.buffer_decl([96], dtype="uint8") - buffer_encoded_3_1 = T.buffer_decl([32], dtype="uint8") - buffer_encoded_4_1 = T.buffer_decl([96], dtype="uint8") - buffer_encoded_5_1 = T.buffer_decl([32], dtype="uint8") - buffer_encoded_6_1 = T.buffer_decl([96], dtype="uint8") - buffer_encoded_7_1 = T.buffer_decl([32], dtype="uint8") - placeholder_global = T.allocate([96], "uint8", "global", annotations={"disable_lower_builtin":True}) - placeholder_d_global = T.allocate([32], "uint8", "global", annotations={"disable_lower_builtin":True}) - ethosu_write_2 = T.allocate([4096], "int8", "global", annotations={"disable_lower_builtin":True}) - placeholder_global_2 = T.allocate([96], "uint8", "global", annotations={"disable_lower_builtin":True}) - placeholder_d_global_2 = T.allocate([32], "uint8", "global", annotations={"disable_lower_builtin":True}) - T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_1[0], 96, placeholder_global[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_1_2[0], 32, placeholder_d_global[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 16, 16, 0, 16, ethosu_write_2[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, placeholder_encoded_1[0], 304, placeholder_encoded_1[304], 304, 12, placeholder_encoded_1_2[0], 80, placeholder_encoded_1_2[80], 80, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_2_1[0], 96, placeholder_global_2[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_3_1[0], 32, placeholder_d_global_2[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, ethosu_write_2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 48, placeholder_global[48], 48, 12, placeholder_d_global[0], 16, placeholder_d_global[16], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_4_1[0], 96, placeholder_global[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_5_1[0], 32, placeholder_d_global[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, ethosu_write_2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[2], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global_2[0], 48, placeholder_global_2[48], 48, 12, placeholder_d_global_2[0], 16, placeholder_d_global_2[16], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_6_1[0], 96, placeholder_global_2[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_7_1[0], 32, placeholder_d_global_2[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, ethosu_write_2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[4], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 48, placeholder_global[48], 48, 12, placeholder_d_global[0], 16, placeholder_d_global[16], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, ethosu_write_2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[6], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global_2[0], 48, placeholder_global_2[48], 48, 12, placeholder_d_global_2[0], 16, placeholder_d_global_2[16], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + buffer1 = T.buffer_decl([128], dtype="uint8") + buffer2 = T.buffer_decl([128], dtype="uint8") + buffer3 = T.buffer_decl([128], dtype="uint8") + buffer4 = T.buffer_decl([608], dtype="uint8") + buffer5 = T.buffer_decl([160], dtype="uint8") + buffer6 = T.buffer_decl([2048], dtype="int8") + p1 = T.allocate([128], "uint8", "global", annotations={"disable_lower_builtin":True}) + p2 = T.allocate([4096], "int8", "global", annotations={"disable_lower_builtin":True}) + p3 = T.allocate([128], "uint8", "global", annotations={"disable_lower_builtin":True}) + T.evaluate(T.call_extern("ethosu_copy", buffer1[0], 128, p1[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 16, 16, 0, 16, p2[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, buffer4[0], 304, buffer4[304], 304, 12, buffer5[0], 80, buffer5[80], 80, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 128, p3[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, p2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, buffer6[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 48, p1[48], 48, 12, p1[96], 16, p1[112], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer3[0], 128, p1[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, p2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, buffer6[2], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p3[0], 48, p3[48], 48, 12, p3[96], 16, p3[112], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_encoded[0], 128, p3[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, p2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, buffer6[4], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 48, p1[48], 48, 12, p1[96], 16, p1[112], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, p2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, buffer6[6], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p3[0], 48, p3[48], 48, 12, p3[96], 16, p3[112], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) __tvm_meta__ = None # fmt: on @@ -455,12 +391,12 @@ def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), ( "ethos-u55-128", MixedReadU55, - [592, 160, 80, 32, 80, 32, 80, 32, 80, 32], + [592, 160, 112, 112, 112, 112], ), ( "ethos-u65-512", MixedReadU65, - [608, 160, 96, 32, 96, 32, 96, 32, 96, 32], + [608, 160, 128, 128, 128, 128], ), ], ) @@ -512,7 +448,7 @@ def _get_func(): tvm.ir.assert_structural_equal(test_mod["main"], reference_mod["main"], True) test_const_size = [value.size for value in list(consts.values())] - assert reference_const_sizes == test_const_size + assert reference_const_sizes.sort() == test_const_size.sort() def test_constant_as_input(): @@ -543,7 +479,7 @@ def get_graph(): # Check tile address for the scalar constant input hasn't been # overwritten. - extern_calls = tir_mod["main"].body.body.body.body.body + extern_calls = tir_mod["main"].body.body.body.body binary_elementwise = extern_calls[-1].value args = binary_elementwise.args diff --git a/tests/python/contrib/test_ethosu/test_merge_constants.py b/tests/python/contrib/test_ethosu/test_merge_constants.py new file mode 100644 index 000000000000..caf09abdb020 --- /dev/null +++ b/tests/python/contrib/test_ethosu/test_merge_constants.py @@ -0,0 +1,561 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest + +pytest.importorskip("ethosu.vela") + +import tvm +from tvm.script import tir as T +from tvm.relay.backend.contrib.ethosu.tir.passes import MergeConstants +import numpy as np + + +def check_const_dictionaries(const_dict, new_const_dict): + assert list(const_dict) == list(new_const_dict) + for key, value in const_dict.items(): + new_value = new_const_dict[key] + assert len(value) == len(new_value) + for i in range(len(value)): + assert value[i] == new_value[i] + + +def test_only_one_operator(): + # fmt: off + @tvm.script.ir_module + class InputModule: + @T.prim_func + def main(buffer2: T.Buffer[(128,), "uint8"], buffer3: T.Buffer[(32,), "uint8"]) -> None: + # function attr dict + T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + buffer1 = T.buffer_decl([8192], "int8") + buffer10 = T.buffer_decl([2048], "int8") + # body + p1 = T.allocate([128], "uint8", "global") + p4 = T.allocate([32], "uint8", "global") + T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 128, p1[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer3[0], 32, p4[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, buffer1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, buffer10[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 128, 12, p4[0], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + + + @tvm.script.ir_module + class ReferenceModule: + @T.prim_func + def main(buffer2: T.Buffer[(160,), "uint8"]) -> None: + # function attr dict + T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + buffer1 = T.buffer_decl([8192], "int8") + buffer10 = T.buffer_decl([2048], "int8") + # body + p4 = T.allocate([160], "uint8", "global") + T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 160, p4[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, buffer1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, buffer10[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p4[0], 128, 12, p4[128], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + # fmt: on + const_dict = { + 0: np.array([0, 10], dtype=np.uint8), + 1: np.array([1, 11], dtype=np.uint8), + } + new_const_dict = {0: np.concatenate((const_dict[0], const_dict[1]))} + test_mod, const_dict = MergeConstants(const_dict)(InputModule) + reference_mod = ReferenceModule + tvm.ir.assert_structural_equal(test_mod, reference_mod, True) + check_const_dictionaries(const_dict, new_const_dict) + + +def test_all_operators_with_weights(): + # fmt: off + @tvm.script.ir_module + class InputModule: + @T.prim_func + def main(buffer2: T.Buffer[(128,), "uint8"], buffer3: T.Buffer[(32,), "uint8"], buffer4: T.Buffer[(112,), "uint8"], buffer5: T.Buffer[(32,), "uint8"], buffer6: T.Buffer[(112,), "uint8"], buffer7: T.Buffer[(32,), "uint8"], buffer8: T.Buffer[(112,), "uint8"], buffer9: T.Buffer[(32,), "uint8"]) -> None: + # function attr dict + T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + buffer1 = T.buffer_decl([8192], "int8") + buffer10 = T.buffer_decl([2048], "int8") + # body + p1 = T.allocate([128], "uint8", "global") + p2 = T.allocate([112], "uint8", "global") + p3 = T.allocate([112], "uint8", "global") + p4 = T.allocate([32], "uint8", "global") + p5 = T.allocate([32], "uint8", "global") + p6 = T.allocate([32], "uint8", "global") + p7 = T.allocate([112], "uint8", "global") + p8 = T.allocate([3], "uint8", "global") + T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 128, p1[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer3[0], 32, p4[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer4[0], 112, p2[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer5[0], 32, p5[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, buffer1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, buffer10[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 128, 12, p4[0], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer6[0], 112, p3[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer7[0], 32, p6[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, buffer1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, buffer10[2], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p2[0], 112, 12, p5[0], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer8[0], 112, p7[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer9[0], 32, p8[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, buffer1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, buffer10[4], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p3[0], 112, 12, p6[0], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, buffer1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, buffer10[6], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p7[0], 112, 12, p8[0], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + + + @tvm.script.ir_module + class ReferenceModule: + @T.prim_func + def main(buffer2: T.Buffer[(160,), "uint8"], buffer4: T.Buffer[(144,), "uint8"], buffer6: T.Buffer[(144,), "uint8"], buffer8: T.Buffer[(144,), "uint8"]) -> None: + # function attr dict + T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + buffer1 = T.buffer_decl([8192], "int8") + buffer10 = T.buffer_decl([2048], "int8") + # body + p4 = T.allocate([160], "uint8", "global") + p7 = T.allocate([144], "uint8", "global") + p10 = T.allocate([144], "uint8", "global") + p11 = T.allocate([144], "uint8", "global") + T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 160, p4[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer4[0], 144, p7[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, buffer1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, buffer10[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p4[0], 128, 12, p4[128], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer6[0], 144, p10[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, buffer1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, buffer10[2], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p7[0], 112, 12, p7[112], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer8[0], 144, p11[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, buffer1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, buffer10[4], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p10[0], 112, 12, p10[112], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, buffer1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, buffer10[6], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p11[0], 112, 12, p11[112], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + # fmt: on + + const_dict = { + 0: np.array([0], dtype=np.uint8), + 1: np.array([1], dtype=np.uint8), + 2: np.array([2], dtype=np.uint8), + 3: np.array([3], dtype=np.uint8), + 4: np.array([4], dtype=np.uint8), + 5: np.array([5], dtype=np.uint8), + 6: np.array([6], dtype=np.uint8), + 7: np.array([7], dtype=np.uint8), + } + new_const_dict = { + 0: np.concatenate((const_dict[0], const_dict[1])), + 1: np.concatenate((const_dict[2], const_dict[3])), + 2: np.concatenate((const_dict[4], const_dict[5])), + 3: np.concatenate((const_dict[6], const_dict[7])), + } + test_mod, const_dict = MergeConstants(const_dict)(InputModule) + reference_mod = ReferenceModule + tvm.ir.assert_structural_equal(test_mod, reference_mod, True) + check_const_dictionaries(const_dict, new_const_dict) + + +def test_operators_with_and_without_weights(): + # fmt: off + @tvm.script.ir_module + class InputModule: + @T.prim_func + def main(buffer2: T.Buffer[(80,), "uint8"], buffer3: T.Buffer[(64,), "uint8"]) -> None: + T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + buffer0 = T.buffer_decl([390336], "int8") + buffer1 = T.buffer_decl([97156], "int8") + buffer6 = T.buffer_decl([390336], "int8") + # body + p2 = T.allocate([80], "uint8", "global") + p3 = T.allocate([64], "uint8", "global") + T.evaluate(T.call_extern("ethosu_pooling", "int8", 214, 227, 2, 214, 0, 227, buffer1[0], 0, 0, 0, T.float32(1), 0, "NHWC", 454, 2, 1, "int8", 214, 114, 2, 214, 0, 114, buffer0[0], 0, 0, 0, T.float32(1), 0, "NHCWB16", 1824, 16, 1, "MAX", 2, 1, 2, 1, 1, 1, 0, 0, 0, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 80, p2[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer3[0], 64, p3[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 214, 114, 2, 214, 0, 114, buffer0[0], 0, 0, 0, T.float32(0.00392157), -128, "NHCWB16", 1824, 16, 1, "int8", 214, 114, 5, 214, 0, 114, buffer6[0], 0, 0, 0, T.float32(0.0174839), -128, "NHCWB16", 1824, 16, 1, 3, 1, 1, 1, 1, 2, p2[0], 80, 0, p3[0], 64, 0, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + + + @tvm.script.ir_module + class ReferenceModule: + @T.prim_func + def main(buffer2: T.Buffer[(144,), "uint8"]) -> None: + T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + buffer0 = T.buffer_decl([390336], "int8") + buffer1 = T.buffer_decl([97156], "int8") + buffer6 = T.buffer_decl([390336], "int8") + # body + p3 = T.allocate([144], "uint8", "global") + T.evaluate(T.call_extern("ethosu_pooling", "int8", 214, 227, 2, 214, 0, 227, buffer1[0], 0, 0, 0, T.float32(1), 0, "NHWC", 454, 2, 1, "int8", 214, 114, 2, 214, 0, 114, buffer0[0], 0, 0, 0, T.float32(1), 0, "NHCWB16", 1824, 16, 1, "MAX", 2, 1, 2, 1, 1, 1, 0, 0, 0, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 144, p3[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 214, 114, 2, 214, 0, 114, buffer0[0], 0, 0, 0, T.float32(0.00392157), -128, "NHCWB16", 1824, 16, 1, "int8", 214, 114, 5, 214, 0, 114, buffer6[0], 0, 0, 0, T.float32(0.0174839), -128, "NHCWB16", 1824, 16, 1, 3, 1, 1, 1, 1, 2, p3[0], 80, 0, p3[80], 64, 0, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + # fmt: on + + const_dict = { + 0: np.array([0], dtype=np.uint8), + 1: np.array([1], dtype=np.uint8), + } + new_const_dict = {0: np.concatenate((const_dict[0], const_dict[1]))} + test_mod, const_dict = MergeConstants(const_dict)(InputModule) + reference_mod = ReferenceModule + tvm.ir.assert_structural_equal(test_mod, reference_mod, True) + check_const_dictionaries(const_dict, new_const_dict) + + +def test_copy_to_buffer_with_local_scope(): + # fmt: off + @tvm.script.ir_module + class InputModule: + @T.prim_func + def main(buffer1: T.Buffer[(64,), "uint8"], + buffer2: T.Buffer[(48,), "uint8"], + buffer3: T.Buffer[(256,), "uint8"], + buffer4: T.Buffer[(256,), "uint8"], + buffer5: T.Buffer[(16,), "uint8"], + buffer6: T.Buffer[(48,), "uint8"], + buffer7: T.Buffer[(256,), "uint8"], + buffer8: T.Buffer[(64,), "uint8"], + buffer9: T.Buffer[(256,), "int8"], + ) -> None: + T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + # body + p1 = T.allocate([48], "uint8", "global") + p2 = T.allocate([48], "uint8", "global") + p3 = T.allocate([256], "int8", "local") + p5 = T.allocate([16], "uint8", "global") + p6 = T.allocate([48], "uint8", "global") + p7 = T.allocate([256], "int8", "local") + T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 48, p1[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer3[0], 48, p2[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer4[0], 256, p3[0], dtype="handle")) # Local + T.evaluate(T.call_extern("ethosu_copy", buffer5[0], 16, p5[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer6[0], 48, p6[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 4, 4, 4, 4, 0, 4, buffer1[0], 0, 0, 0, T.float32(0.00392081), -128, "NHWC", 16, 4, 1, "int8", 4, 4, 4, 4, 0, 4, buffer9[0], 0, 0, 0, T.float32(0.00839574), -128, "NHCWB16", 64, 16, 1, 1, 1, 1, 1, 1, 1, p1[0], 48, 0, p2[0], 48, 0, 0, 0, 0, "TANH", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer7[0], 256, p7[0], dtype="handle")) # Local + T.evaluate(T.call_extern("ethosu_depthwise_conv2d", "int8", 4, 4, 4, 4, 0, 4, buffer9[0], 0, 0, 0, T.float32(0.0078125), 0, "NHCWB16", 64, 16, 1, "int8", 4, 4, 4, 4, 0, 4, buffer8[0], 0, 0, 0, T.float32(0.00372155), -128, "NHWC", 16, 4, 1, 1, 1, 1, 1, 1, 1, p5[0], 16, 0, p6[0], 48, 0, 0, 0, 0, "TANH", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + + + @tvm.script.ir_module + class ReferenceModule: + @T.prim_func + def main(buffer1: T.Buffer[(64,), "uint8"], + buffer2: T.Buffer[(96,), "uint8"], + buffer4: T.Buffer[(256,), "uint8"], + buffer5: T.Buffer[(64,), "uint8"], + buffer7: T.Buffer[(256,), "uint8"], + buffer8: T.Buffer[(64,), "uint8"], + buffer9: T.Buffer[(256,), "int8"], + ) -> None: + T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + # body + p1 = T.allocate([96], "uint8", "global") + p2 = T.allocate([64], "uint8", "global") + p3 = T.allocate([256], "int8", "local") + p7 = T.allocate([256], "int8", "local") + T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 96, p1[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer4[0], 256, p3[0], dtype="handle")) # Local + T.evaluate(T.call_extern("ethosu_copy", buffer5[0], 64, p2[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 4, 4, 4, 4, 0, 4, buffer1[0], 0, 0, 0, T.float32(0.00392081), -128, "NHWC", 16, 4, 1, "int8", 4, 4, 4, 4, 0, 4, buffer9[0], 0, 0, 0, T.float32(0.00839574), -128, "NHCWB16", 64, 16, 1, 1, 1, 1, 1, 1, 1, p1[0], 48, 0, p1[48], 48, 0, 0, 0, 0, "TANH", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer7[0], 256, p7[0], dtype="handle")) # Local + T.evaluate(T.call_extern("ethosu_depthwise_conv2d", "int8", 4, 4, 4, 4, 0, 4, buffer9[0], 0, 0, 0, T.float32(0.0078125), 0, "NHCWB16", 64, 16, 1, "int8", 4, 4, 4, 4, 0, 4, buffer8[0], 0, 0, 0, T.float32(0.00372155), -128, "NHWC", 16, 4, 1, 1, 1, 1, 1, 1, 1, p2[0], 16, 0, p2[16], 48, 0, 0, 0, 0, "TANH", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + # fmt: on + + const_dict = { + 1: np.array([1], dtype=np.uint8), + 2: np.array([2], dtype=np.uint8), + 3: np.array([3], dtype=np.uint8), + 4: np.array([4], dtype=np.uint8), + 5: np.array([5], dtype=np.uint8), + 6: np.array([6], dtype=np.uint8), + } + new_const_dict = { + 1: np.concatenate((const_dict[1], const_dict[2])), + 2: const_dict[3], + 3: np.concatenate((const_dict[4], const_dict[5])), + 4: const_dict[6], + } + test_mod, const_dict = MergeConstants(const_dict)(InputModule) + reference_mod = ReferenceModule + tvm.ir.assert_structural_equal(test_mod, reference_mod, True) + check_const_dictionaries(const_dict, new_const_dict) + + +def test_no_copies(): + # fmt: off + @tvm.script.ir_module + class InputModule: + @T.prim_func + def main() -> None: + # function attr dict + T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + placeholder = T.buffer_decl([20], "int8") + ethosu_write = T.buffer_decl([16], "int8") + # body + ethosu_write_4 = T.allocate([16], "int8", "global") + T.evaluate(T.call_extern("ethosu_binary_elementwise", "int8", 1, 4, 4, 1, 0, 4, placeholder[0], 0, 0, 0, T.float32(0.00783747), -128, "NHWC", 1, 4, 1, "int8", 1, 4, 1, 1, 0, 4, placeholder[16], 0, 0, 0, T.float32(0.00783747), -128, "NHWC", 1, 1, 1, "int8", 1, 4, 4, 1, 0, 4, ethosu_write_4[0], 0, 0, 0, T.float32(0.00783747), -128, "NHWC", 1, 4, 1, "MAX", 0, "CLIP", -128, 127, "TFL", 1, 4, 4, dtype="handle")) + T.evaluate(T.call_extern("ethosu_identity", "int8", 1, 4, 4, 1, 0, 4, ethosu_write_4[0], 0, 0, 0, T.float32(1), 0, "NHWC", 1, 4, 1, "int8", 1, 4, 4, 1, 0, 4, ethosu_write[0], 0, 0, 0, T.float32(1), 0, "NHWC", 1, 4, 1, "AVG", 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + + @tvm.script.ir_module + class ReferenceModule: + @T.prim_func + def main() -> None: + # function attr dict + T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + placeholder = T.buffer_decl([20], "int8") + ethosu_write = T.buffer_decl([16], "int8") + # body + ethosu_write_4 = T.allocate([16], "int8", "global") + T.evaluate(T.call_extern("ethosu_binary_elementwise", "int8", 1, 4, 4, 1, 0, 4, placeholder[0], 0, 0, 0, T.float32(0.00783747), -128, "NHWC", 1, 4, 1, "int8", 1, 4, 1, 1, 0, 4, placeholder[16], 0, 0, 0, T.float32(0.00783747), -128, "NHWC", 1, 1, 1, "int8", 1, 4, 4, 1, 0, 4, ethosu_write_4[0], 0, 0, 0, T.float32(0.00783747), -128, "NHWC", 1, 4, 1, "MAX", 0, "CLIP", -128, 127, "TFL", 1, 4, 4, dtype="handle")) + T.evaluate(T.call_extern("ethosu_identity", "int8", 1, 4, 4, 1, 0, 4, ethosu_write_4[0], 0, 0, 0, T.float32(1), 0, "NHWC", 1, 4, 1, "int8", 1, 4, 4, 1, 0, 4, ethosu_write[0], 0, 0, 0, T.float32(1), 0, "NHWC", 1, 4, 1, "AVG", 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + # fmt: on + + const_dict = {} + new_const_dict = {} + test_mod, const_dict = MergeConstants(const_dict)(InputModule) + reference_mod = ReferenceModule + tvm.ir.assert_structural_equal(test_mod, reference_mod, True) + check_const_dictionaries(const_dict, new_const_dict) + + +def test_copies_to_the_same_buffer(): + # fmt: off + @tvm.script.ir_module + class InputModule: + @T.prim_func + def main(buffer2: T.Buffer[(128,), "uint8"], buffer3: T.Buffer[(32,), "uint8"]) -> None: + # function attr dict + T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + buffer1 = T.buffer_decl([8192], "int8") + buffer10 = T.buffer_decl([2048], "int8") + # body + p1 = T.allocate([128], "uint8", "global") + p4 = T.allocate([32], "uint8", "global") + T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 128, p1[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer3[0], 32, p4[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, buffer1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, buffer10[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 128, 12, p4[0], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 128, p1[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer3[0], 32, p4[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, buffer1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, buffer10[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 128, 12, p4[0], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + + + @tvm.script.ir_module + class ReferenceModule: + @T.prim_func + def main(buffer2: T.Buffer[(160,), "uint8"]) -> None: + # function attr dict + T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + buffer1 = T.buffer_decl([8192], "int8") + buffer10 = T.buffer_decl([2048], "int8") + # body + p5 = T.allocate([160], "uint8", "global") + T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 160, p5[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, buffer1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, buffer10[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p5[0], 128, 12, p5[128], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 160, p5[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, buffer1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, buffer10[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p5[0], 128, 12, p5[128], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + # fmt: on + + const_dict = { + 0: np.array([0], dtype=np.uint8), + 1: np.array([1], dtype=np.uint8), + } + new_const_dict = {0: np.concatenate((const_dict[0], const_dict[1]))} + test_mod, const_dict = MergeConstants(const_dict)(InputModule) + reference_mod = ReferenceModule + tvm.ir.assert_structural_equal(test_mod, reference_mod, True) + check_const_dictionaries(const_dict, new_const_dict) + + +def test_read_from_the_same_buffer(): + # fmt: off + @tvm.script.ir_module + class InputModule: + @T.prim_func + def main(placeholder: T.Buffer[(8192,), "int8"], buffer1: T.Buffer[(368,), "uint8"], buffer2: T.Buffer[(96,), "uint8"], ethosu_write: T.Buffer[(2048,), "int8"]) -> None: + # function attr dict + T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + # buffer definition + T.preflattened_buffer(placeholder, [1, 16, 16, 32], dtype="int8", data=placeholder.data) + T.preflattened_buffer(ethosu_write, [1, 16, 16, 8], dtype="int8", data=ethosu_write.data) + # body + p1 = T.allocate([368], "uint8", "global") + p2 = T.allocate([96], "uint8", "global") + T.evaluate(T.call_extern("ethosu_copy", buffer1[0], 368, p1[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 96, p2[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 192, p1[192], 176, 12, p2[0], 48, p2[48], 48, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + __tvm_meta__ = None + + + @tvm.script.ir_module + class ReferenceModule: + @T.prim_func + def main(placeholder: T.Buffer[(8192,), "int8"], buffer1: T.Buffer[(464,), "uint8"], ethosu_write: T.Buffer[(2048,), "int8"]) -> None: + # function attr dict + T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + # body + p1 = T.allocate([464], "uint8", "global") + T.evaluate(T.call_extern("ethosu_copy", buffer1[0], 464, p1[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 192, p1[192], 176, 12, p1[368], 48, p1[416], 48, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + __tvm_meta__ = None + # fmt: on + + const_dict = { + 1: np.array([1], dtype=np.uint8), + 2: np.array([2], dtype=np.uint8), + } + new_const_dict = {1: np.concatenate((const_dict[1], const_dict[2]))} + test_mod, const_dict = MergeConstants(const_dict)(InputModule) + reference_mod = ReferenceModule + tvm.ir.assert_structural_equal(test_mod, reference_mod, True) + check_const_dictionaries(const_dict, new_const_dict) + + +def test_cycle_count(): + # fmt: off + @tvm.script.ir_module + class InputModule: + @T.prim_func + def main(buffer2: T.Buffer[(128,), "uint8"], buffer3: T.Buffer[(32,), "uint8"], buffer4: T.Buffer[(112,), "uint8"], buffer5: T.Buffer[(32,), "uint8"], buffer6: T.Buffer[(112,), "uint8"], buffer7: T.Buffer[(32,), "uint8"], buffer8: T.Buffer[(112,), "uint8"], buffer9: T.Buffer[(32,), "uint8"]) -> None: + # function attr dict + T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + v1a = T.var("int32") + v1b = T.var("int32") + v1c = T.var("int32") + v2a = T.var("int32") + v2b = T.var("int32") + v2c = T.var("int32") + v3a = T.var("int32") + v3b = T.var("int32") + v3c = T.var("int32") + v4a = T.var("int32") + v4b = T.var("int32") + v4c = T.var("int32") + buffer1 = T.buffer_decl([8192], "int8") + buffer10 = T.buffer_decl([2048], "int8") + # body + p1 = T.allocate([128], "uint8", "global") + p2 = T.allocate([112], "uint8", "global") + p3 = T.allocate([112], "uint8", "global") + p4 = T.allocate([32], "uint8", "global") + p5 = T.allocate([32], "uint8", "global") + p6 = T.allocate([32], "uint8", "global") + p7 = T.allocate([112], "uint8", "global") + p8 = T.allocate([3], "uint8", "global") + with T.attr(T.iter_var(v1a, None, "DataPar", ""), "pragma_compute_cycles_hint", 100): + T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 128, p1[0], dtype="handle")) + with T.attr(T.iter_var(v1b, None, "DataPar", ""), "pragma_compute_cycles_hint", 101): + T.evaluate(T.call_extern("ethosu_copy", buffer3[0], 32, p4[0], dtype="handle")) + with T.attr(T.iter_var(v2a, None, "DataPar", ""), "pragma_compute_cycles_hint", 102): + T.evaluate(T.call_extern("ethosu_copy", buffer4[0], 112, p2[0], dtype="handle")) + with T.attr(T.iter_var(v2b, None, "DataPar", ""), "pragma_compute_cycles_hint", 103): + T.evaluate(T.call_extern("ethosu_copy", buffer5[0], 32, p5[0], dtype="handle")) + with T.attr(T.iter_var(v1c, None, "DataPar", ""), "pragma_compute_cycles_hint", 300): + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, buffer1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, buffer10[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 128, 12, p4[0], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + with T.attr(T.iter_var(v3a, None, "DataPar", ""), "pragma_compute_cycles_hint", 104): + T.evaluate(T.call_extern("ethosu_copy", buffer6[0], 112, p3[0], dtype="handle")) + with T.attr(T.iter_var(v3b, None, "DataPar", ""), "pragma_compute_cycles_hint", 105): + T.evaluate(T.call_extern("ethosu_copy", buffer7[0], 32, p6[0], dtype="handle")) + with T.attr(T.iter_var(v2c, None, "DataPar", ""), "pragma_compute_cycles_hint", 301): + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, buffer1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, buffer10[2], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p2[0], 112, 12, p5[0], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + with T.attr(T.iter_var(v4a, None, "DataPar", ""), "pragma_compute_cycles_hint", 106): + T.evaluate(T.call_extern("ethosu_copy", buffer8[0], 112, p7[0], dtype="handle")) + with T.attr(T.iter_var(v4b, None, "DataPar", ""), "pragma_compute_cycles_hint", 107): + T.evaluate(T.call_extern("ethosu_copy", buffer9[0], 32, p8[0], dtype="handle")) + with T.attr(T.iter_var(v3c, None, "DataPar", ""), "pragma_compute_cycles_hint", 302): + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, buffer1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, buffer10[4], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p3[0], 112, 12, p6[0], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + with T.attr(T.iter_var(v4c, None, "DataPar", ""), "pragma_compute_cycles_hint", 303): + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, buffer1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, buffer10[6], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p7[0], 112, 12, p8[0], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + + + @tvm.script.ir_module + class ReferenceModule: + @T.prim_func + def main(buffer2: T.Buffer[(160,), "uint8"], buffer4: T.Buffer[(144,), "uint8"], buffer6: T.Buffer[(144,), "uint8"], buffer8: T.Buffer[(144,), "uint8"]) -> None: + # function attr dict + T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + v1a = T.var("int32") + v1c = T.var("int32") + v2a = T.var("int32") + v2c = T.var("int32") + v3a = T.var("int32") + v3c = T.var("int32") + v4a = T.var("int32") + v4c = T.var("int32") + buffer1 = T.buffer_decl([8192], "int8") + buffer10 = T.buffer_decl([2048], "int8") + # body + p4 = T.allocate([160], "uint8", "global") + p7 = T.allocate([144], "uint8", "global") + p10 = T.allocate([144], "uint8", "global") + p11 = T.allocate([144], "uint8", "global") + with T.attr(T.iter_var(v1a, None, "DataPar", ""), "pragma_compute_cycles_hint", 201): + T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 160, p4[0], dtype="handle")) + with T.attr(T.iter_var(v2a, None, "DataPar", ""), "pragma_compute_cycles_hint", 205): + T.evaluate(T.call_extern("ethosu_copy", buffer4[0], 144, p7[0], dtype="handle")) + with T.attr(T.iter_var(v1c, None, "DataPar", ""), "pragma_compute_cycles_hint", 300): + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, buffer1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, buffer10[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p4[0], 128, 12, p4[128], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + with T.attr(T.iter_var(v3a, None, "DataPar", ""), "pragma_compute_cycles_hint", 209): + T.evaluate(T.call_extern("ethosu_copy", buffer6[0], 144, p10[0], dtype="handle")) + with T.attr(T.iter_var(v2c, None, "DataPar", ""), "pragma_compute_cycles_hint", 301): + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, buffer1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, buffer10[2], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p7[0], 112, 12, p7[112], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + with T.attr(T.iter_var(v4a, None, "DataPar", ""), "pragma_compute_cycles_hint", 213): + T.evaluate(T.call_extern("ethosu_copy", buffer8[0], 144, p11[0], dtype="handle")) + with T.attr(T.iter_var(v3c, None, "DataPar", ""), "pragma_compute_cycles_hint", 302): + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, buffer1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, buffer10[4], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p10[0], 112, 12, p10[112], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + with T.attr(T.iter_var(v4c, None, "DataPar", ""), "pragma_compute_cycles_hint", 303): + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, buffer1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, buffer10[6], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p11[0], 112, 12, p11[112], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + # fmt: on + + const_dict = { + 0: np.array([0], dtype=np.uint8), + 1: np.array([1], dtype=np.uint8), + 2: np.array([2], dtype=np.uint8), + 3: np.array([3], dtype=np.uint8), + 4: np.array([4], dtype=np.uint8), + 5: np.array([5], dtype=np.uint8), + 6: np.array([6], dtype=np.uint8), + 7: np.array([7], dtype=np.uint8), + } + new_const_dict = { + 0: np.concatenate((const_dict[0], const_dict[1])), + 1: np.concatenate((const_dict[2], const_dict[3])), + 2: np.concatenate((const_dict[4], const_dict[5])), + 3: np.concatenate((const_dict[6], const_dict[7])), + } + test_mod, const_dict = MergeConstants(const_dict)(InputModule) + reference_mod = ReferenceModule + tvm.ir.assert_structural_equal(test_mod, reference_mod, True) + check_const_dictionaries(const_dict, new_const_dict) + + +def test_multiple_prim_funcs(): + # fmt: off + @tvm.script.ir_module + class InputModule: + @T.prim_func + def main(): + T.evaluate(0) + + @T.prim_func + def abc(): + T.evaluate(0) + # fmt: on + + err_rgx = ( + r"Expected a single primitive function called 'main'. " + r"Please run the MergeConstants pass in conjunction with the LowerToTIR\(\) pass." + ) + with pytest.raises(tvm.TVMError, match=err_rgx): + MergeConstants({})(InputModule) + + +def test_no_main_prim_func(): + # fmt: off + @tvm.script.ir_module + class InputModule: + @T.prim_func + def abs(): + T.evaluate(0) + # fmt: on + + err_rgx = ( + r"Expected a single primitive function called 'main'. " + r"Please run the MergeConstants pass in conjunction with the LowerToTIR\(\) pass." + ) + with pytest.raises(tvm.TVMError, match=err_rgx): + MergeConstants({})(InputModule) diff --git a/tests/python/contrib/test_ethosu/test_networks.py b/tests/python/contrib/test_ethosu/test_networks.py index 075565cd92a6..02643f6c1ded 100644 --- a/tests/python/contrib/test_ethosu/test_networks.py +++ b/tests/python/contrib/test_ethosu/test_networks.py @@ -44,13 +44,13 @@ @pytest.mark.parametrize( "accel_type, model_url, workspace_size", [ - ("ethos-u65-256", MOBILENET_V1_URL, 1892704), - ("ethos-u65-256", MOBILENET_V2_URL, 2257984), - ("ethos-u55-256", MOBILENET_V1_URL, 1892704), - ("ethos-u55-256", MOBILENET_V2_URL, 2257984), - ("ethos-u55-128", MOBILENET_V2_URL, 2257984), - ("ethos-u55-64", MOBILENET_V2_URL, 2257984), - ("ethos-u55-32", MOBILENET_V2_URL, 2258000), + ("ethos-u65-256", MOBILENET_V1_URL, 1793376), + ("ethos-u65-256", MOBILENET_V2_URL, 2218160), + ("ethos-u55-256", MOBILENET_V1_URL, 1793376), + ("ethos-u55-256", MOBILENET_V2_URL, 2218160), + ("ethos-u55-128", MOBILENET_V2_URL, 2218160), + ("ethos-u55-64", MOBILENET_V2_URL, 2218160), + ("ethos-u55-32", MOBILENET_V2_URL, 2218160), ], ) def test_networks_without_usmp(accel_type, model_url, workspace_size): diff --git a/tests/python/contrib/test_ethosu/test_remove_concatenates.py b/tests/python/contrib/test_ethosu/test_remove_concatenates.py index cc996e59412c..d2c759a0ae4d 100644 --- a/tests/python/contrib/test_ethosu/test_remove_concatenates.py +++ b/tests/python/contrib/test_ethosu/test_remove_concatenates.py @@ -41,9 +41,6 @@ def main(placeholder: T.Buffer[(1536,), "int8"], placeholder_1: T.Buffer[(1280,) buffer_5 = T.buffer_decl([160], "uint8") buffer_6 = T.buffer_decl([2992], "uint8") buffer_7 = T.buffer_decl([160], "uint8") - T.preflattened_buffer(placeholder, [1, 8, 12, 16], "int8", data=placeholder.data) - T.preflattened_buffer(placeholder_1, [1, 8, 10, 16], "int8", data=placeholder_1.data) - T.preflattened_buffer(T_concat, [1, 8, 32, 16], "int8", data=T_concat.data) # body T_concat_1 = T.allocate([2816], "int8", "global", annotations={"disable_lower_builtin":True}) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 10, 16, 8, 0, 10, placeholder_1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 160, 16, 1, "int8", 8, 10, 16, 8, 0, 10, T_concat_1[192], 0, 0, 0, T.float32(0.25), 14, "NHWC", 352, 16, 1, 3, 3, 1, 1, 1, 1, buffer[0], 2992, T.int8(-1), T.int8(-1), 12, buffer_1[0], 160, T.int8(-1), T.int8(-1), 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) diff --git a/tests/python/contrib/test_ethosu/test_replace_conv2d.py b/tests/python/contrib/test_ethosu/test_replace_conv2d.py index 63f9fc44c778..46a3c5a15bf5 100644 --- a/tests/python/contrib/test_ethosu/test_replace_conv2d.py +++ b/tests/python/contrib/test_ethosu/test_replace_conv2d.py @@ -373,8 +373,6 @@ def main(placeholder_5: T.Buffer[(192,), "int8"], ethosu_write_1: T.Buffer[(512, buffer_1 = T.buffer_decl([80], "uint8") buffer_2 = T.buffer_decl([320], "uint8") buffer_3 = T.buffer_decl([160], "uint8") - T.preflattened_buffer(placeholder_5, [1, 8, 8, 3], 'int8', data=placeholder_5.data) - T.preflattened_buffer(ethosu_write_1, [1, 8, 8, 8], 'int8', data=ethosu_write_1.data) # body ethosu_write_2 = T.allocate([1024], "int8", "global", annotations={"disable_lower_builtin": True}) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 4, 3, 8, 0, 4, placeholder_5[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 8, 4, 32, 8, 0, 4, ethosu_write_2[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 32, 1, 1, 1, 1, 1, 1, 1, buffer_3[0], 160, T.int8(-1), T.int8(-1), 12, buffer_2[0], 320, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) @@ -394,8 +392,6 @@ def main(placeholder_5: T.Buffer[(192,), "int8"], ethosu_write_1: T.Buffer[(512, buffer_1 = T.buffer_decl([320], "uint8") buffer_2 = T.buffer_decl([1312], "uint8") buffer_3 = T.buffer_decl([2608], "uint8") - T.preflattened_buffer(placeholder_5, [1, 8, 8, 3], 'int8', data=placeholder_5.data) - T.preflattened_buffer(ethosu_write_1, [1, 8, 8, 8], 'int8', data=ethosu_write_1.data) # body ethosu_write_2 = T.allocate([1536], "int8", "global", annotations={"disable_lower_builtin": True}) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 6, 8, 3, 6, 0, 8, placeholder_5[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 5, 8, 32, 5, 0, 8, ethosu_write_2[256], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 3, 3, 1, 1, 1, 1, buffer_2[0], 1312, T.int8(-1), T.int8(-1), 12, buffer_1[0], 320, T.int8(-1), T.int8(-1), 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) @@ -415,8 +411,6 @@ def main(placeholder_5: T.Buffer[(768,), "int8"], ethosu_write_1: T.Buffer[(640, buffer_1 = T.buffer_decl([80], "uint8") buffer_2 = T.buffer_decl([320], "uint8") buffer_3 = T.buffer_decl([880], "uint8") - T.preflattened_buffer(placeholder_5, [1, 16, 16, 3], 'int8', data=placeholder_5.data) - T.preflattened_buffer(ethosu_write_1, [1, 20, 4, 8], 'int8', data=ethosu_write_1.data) # body ethosu_write_2 = T.allocate([2560], "int8", "global", annotations={"disable_lower_builtin": True}) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 16, 3, 8, 0, 16, placeholder_5[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 48, 3, 1, "int8", 8, 8, 32, 8, 0, 8, ethosu_write_2[512], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 2, 3, 2, 1, 2, 1, buffer_3[0], 880, T.int8(-1), T.int8(-1), 12, buffer_2[0], 320, T.int8(-1), T.int8(-1), 2, 1, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) @@ -438,8 +432,6 @@ def main(placeholder_5: T.Buffer[(1024,), "int8"], ethosu_write_1: T.Buffer[(204 buffer_1 = T.buffer_decl([352], "uint8") buffer_2 = T.buffer_decl([272], "uint8") buffer_3 = T.buffer_decl([11040], "uint8") - T.preflattened_buffer(placeholder_5, [1, 8, 1, 8, 16], 'int8', data=placeholder_5.data) - T.preflattened_buffer(ethosu_write_1, [1, 8, 2, 8, 16], 'int8', data=ethosu_write_1.data) # body ethosu_write_2 = T.allocate([2304], "int8", "global", annotations={"disable_lower_builtin": True}) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 6, 8, 3, 6, 0, 8, placeholder_5[0], 0, 0, 0, T.float32(0.5), 10, "NHCWB16", 128, 16, 1, "int8", 5, 8, 35, 5, 0, 8, ethosu_write_2[384], 0, 0, 0, T.float32(0.25), 14, "NHCWB16", 384, 16, 128, 3, 3, 1, 1, 1, 1, buffer[0], 1456, T.int8(-1), T.int8(-1), 12, buffer_1[0], 352, T.int8(-1), T.int8(-1), 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) @@ -459,8 +451,6 @@ def main(placeholder: T.Buffer[(192,), "int8"], ethosu_write: T.Buffer[(8192,), buffer_1 = T.buffer_decl([320], "uint8") buffer_2 = T.buffer_decl([304], "uint8") buffer_3 = T.buffer_decl([80], "uint8") - T.preflattened_buffer(placeholder, [1, 8, 8, 3], 'int8', data=placeholder.data) - T.preflattened_buffer(ethosu_write, [1, 32, 32, 8], 'int8', data=ethosu_write.data) # body ethosu_write_1 = T.allocate([4096], "int8", "global", annotations={"disable_lower_builtin":True}) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 4, 8, 3, 4, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 8, 16, 32, 8, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 512, 32, 1, 1, 1, 1, 1, 1, 1, buffer[0], 160, T.int8(-1), T.int8(-1), 12, buffer_1[0], 320, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "ZEROS", 0, 0, 0, dtype="handle")) @@ -480,8 +470,6 @@ def main(placeholder: T.Buffer[(1024,), "int8"], ethosu_write: T.Buffer[(32768,) buffer_1 = T.buffer_decl([352], "uint8") buffer_2 = T.buffer_decl([11040], "uint8") buffer_3 = T.buffer_decl([272], "uint8") - T.preflattened_buffer(placeholder, [1, 8, 1, 8, 16], 'int8', data=placeholder.data) - T.preflattened_buffer(ethosu_write, [1, 32, 2, 32, 16], 'int8', data=ethosu_write.data) # body ethosu_write_1 = T.allocate([12288], "int8", "global", annotations={"disable_lower_builtin":True}) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 8, 3, 8, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHCWB16", 128, 16, 1, "int8", 16, 16, 35, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHCWB16", 768, 16, 256, 3, 3, 1, 1, 1, 1, buffer[0], 1456, T.int8(-1), T.int8(-1), 12, buffer_1[0], 352, T.int8(-1), T.int8(-1), 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NEAREST", 0, 0, 0, dtype="handle")) @@ -641,8 +629,6 @@ def main(placeholder_3: T.Buffer[(960,), "int8"], ethosu_write_1: T.Buffer[(1024 T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) buffer = T.buffer_decl([848], "uint8") buffer_1 = T.buffer_decl([160], "uint8") - T.preflattened_buffer(placeholder_3, [1, 10, 12, 8], 'int8', data=placeholder_3.data) - T.preflattened_buffer(ethosu_write_1, [1, 8, 8, 16], 'int8', data=ethosu_write_1.data) # body T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 8, 4, 8, 0, 8, placeholder_3[120], 0, 0, 0, T.float32(0.5), 10, "NHWC", 96, 8, 1, "int8", 8, 8, 16, 8, 0, 8, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 16, 1, 3, 3, 1, 1, 1, 1, buffer[0], 848, T.int8(-1), T.int8(-1), 12, buffer_1[0], 160, T.int8(-1), T.int8(-1), 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) __tvm_meta__ = None @@ -656,8 +642,6 @@ def main(placeholder_3: T.Buffer[(315,), "int8"], ethosu_write_1: T.Buffer[(240, T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) buffer = T.buffer_decl([160], "uint8") buffer_1 = T.buffer_decl([656], "uint8") - T.preflattened_buffer(placeholder_3, [1, 7, 9, 5], 'int8', data=placeholder_3.data) - T.preflattened_buffer(ethosu_write_1, [1, 3, 5, 16], 'int8', data=ethosu_write_1.data) # body T.evaluate(T.call_extern("ethosu_conv2d", "int8", 3, 5, 3, 3, 0, 5, placeholder_3[146], 0, 0, 0, T.float32(0.5), 10, "NHWC", 45, 5, 1, "int8", 3, 5, 16, 3, 0, 5, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 80, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 656, T.int8(-1), T.int8(-1), 12, buffer[0], 160, T.int8(-1), T.int8(-1), 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) __tvm_meta__ = None @@ -700,8 +684,6 @@ def main(placeholder_3: T.Buffer[(192,), "int8"], ethosu_write_1: T.Buffer[(768, T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) buffer = T.buffer_decl([160], "uint8") buffer_1 = T.buffer_decl([848], "uint8") - T.preflattened_buffer(placeholder_3, [4, 6, 8, 1], 'int8', data=placeholder_3.data) - T.preflattened_buffer(ethosu_write_1, [1, 8, 6, 16], 'int8', data=ethosu_write_1.data) # body T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, T.int8(-1), T.int8(-1), 12, buffer[0], 160, T.int8(-1), T.int8(-1), 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[72], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[384], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, T.int8(-1), T.int8(-1), 12, buffer[0], 160, T.int8(-1), T.int8(-1), 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) @@ -716,8 +698,6 @@ def main(placeholder_3: T.Buffer[(192,), "int8"], ethosu_write_1: T.Buffer[(768, T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) buffer = T.buffer_decl([160], "uint8") buffer_1 = T.buffer_decl([848], "uint8") - T.preflattened_buffer(placeholder_3, [1, 24, 8], 'int8', data=placeholder_3.data) - T.preflattened_buffer(ethosu_write_1, [1, 8, 6, 16], 'int8', data=ethosu_write_1.data) # body T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, T.int8(-1), T.int8(-1), 12, buffer[0], 160, T.int8(-1), T.int8(-1), 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[72], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[384], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, T.int8(-1), T.int8(-1), 12, buffer[0], 160, T.int8(-1), T.int8(-1), 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) @@ -732,8 +712,6 @@ def main(placeholder_3: T.Buffer[(192,), "int8"], ethosu_write_1: T.Buffer[(768, T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) buffer = T.buffer_decl([160], "uint8") buffer_1 = T.buffer_decl([848], "uint8") - T.preflattened_buffer(placeholder_3, [192, 1], 'int8', data=placeholder_3.data) - T.preflattened_buffer(ethosu_write_1, [1, 8, 6, 16], 'int8', data=ethosu_write_1.data) # body T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, T.int8(-1), T.int8(-1), 12, buffer[0], 160, T.int8(-1), T.int8(-1), 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[72], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[384], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, T.int8(-1), T.int8(-1), 12, buffer[0], 160, T.int8(-1), T.int8(-1), 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) @@ -748,8 +726,6 @@ def main(placeholder_3: T.Buffer[(192,), "int8"], ethosu_write_1: T.Buffer[(768, T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) buffer = T.buffer_decl([160], "uint8") buffer_1 = T.buffer_decl([848], "uint8") - T.preflattened_buffer(placeholder_3, [192], 'int8', data=placeholder_3.data) - T.preflattened_buffer(ethosu_write_1, [1, 8, 6, 16], 'int8', data=ethosu_write_1.data) # body T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, T.int8(-1), T.int8(-1), 12, buffer[0], 160, T.int8(-1), T.int8(-1), 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[72], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[384], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, T.int8(-1), T.int8(-1), 12, buffer[0], 160, T.int8(-1), T.int8(-1), 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) diff --git a/tests/python/contrib/test_ethosu/test_replace_copy.py b/tests/python/contrib/test_ethosu/test_replace_copy.py index 932df71d2402..6b97b38d80e6 100644 --- a/tests/python/contrib/test_ethosu/test_replace_copy.py +++ b/tests/python/contrib/test_ethosu/test_replace_copy.py @@ -34,16 +34,11 @@ class ReferenceModule: def main(placeholder_3: T.Buffer[(8192,), "int8"], ethosu_write_1: T.Buffer[(2048,), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - buffer = T.buffer_decl([80], "uint8") - buffer_1 = T.buffer_decl([304], "uint8") - T.preflattened_buffer(placeholder_3, [1, 16, 16, 32], dtype="int8", data=placeholder_3.data) - T.preflattened_buffer(ethosu_write_1, [1, 16, 16, 8], dtype="int8", data=ethosu_write_1.data) + buffer_1 = T.buffer_decl([384], "uint8") # body - placeholder_global = T.allocate([304], "uint8", "global", annotations={"disable_lower_builtin": True}) - placeholder_d_global = T.allocate([80], "uint8", "global", annotations={"disable_lower_builtin": True}) - T.evaluate(T.call_extern("ethosu_copy", buffer_1[0], 304, placeholder_global[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", buffer[0], 80, placeholder_d_global[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder_3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 8, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 304, T.int8(-1), T.int8(-1), 12, placeholder_d_global[0], 80, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + placeholder_global = T.allocate([384], "uint8", "global", annotations={"disable_lower_builtin": True}) + T.evaluate(T.call_extern("ethosu_copy", buffer_1[0], 384, placeholder_global[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder_3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 8, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 304, T.int8(-1), T.int8(-1), 12, placeholder_global[304], 80, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) __tvm_meta__ = None # fmt: on @@ -80,23 +75,15 @@ class WeightStream: def main(placeholder_5: T.Buffer[(8192,), "int8"], ethosu_write_1: T.Buffer[(4096,), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - buffer = T.buffer_decl([416], "uint8") - buffer_1 = T.buffer_decl([112], "uint8") - buffer_2 = T.buffer_decl([272], "uint8") - buffer_3 = T.buffer_decl([64], "uint8") - T.preflattened_buffer(placeholder_5, [1, 16, 16, 32], dtype="int8", data=placeholder_5.data) - T.preflattened_buffer(ethosu_write_1, [1, 16, 16, 16], dtype="int8", data=ethosu_write_1.data) + buffer = T.buffer_decl([528], "uint8") + buffer_2 = T.buffer_decl([336], "uint8") # body - placeholder_global_unrolled_iter_0 = T.allocate([416], "uint8", "global", annotations={"disable_lower_builtin": True}) - placeholder_d_global_unrolled_iter_0 = T.allocate([112], "uint8", "global", annotations={"disable_lower_builtin": True}) - placeholder_global_unrolled_iter_1 = T.allocate([272], "uint8", "global", annotations={"disable_lower_builtin": True}) - placeholder_d_global_unrolled_iter_1 = T.allocate([64], "uint8", "global", annotations={"disable_lower_builtin": True}) - T.evaluate(T.call_extern("ethosu_copy", buffer[0], 416, placeholder_global_unrolled_iter_0[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", buffer_1[0], 112, placeholder_d_global_unrolled_iter_0[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", buffer_2[0], 272, placeholder_global_unrolled_iter_1[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", buffer_3[0], 64, placeholder_d_global_unrolled_iter_1[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder_5[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 10, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, placeholder_global_unrolled_iter_0[0], 416, T.int8(-1), T.int8(-1), 12, placeholder_d_global_unrolled_iter_0[0], 112, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder_5[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 6, 16, 0, 16, ethosu_write_1[10], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, placeholder_global_unrolled_iter_1[0], 272, T.int8(-1), T.int8(-1), 12, placeholder_d_global_unrolled_iter_1[0], 64, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + placeholder_d_global = T.allocate([528], "uint8", "global", annotations={"disable_lower_builtin": True}) + placeholder_d_global_1 = T.allocate([336], "uint8", "global", annotations={"disable_lower_builtin": True}) + T.evaluate(T.call_extern("ethosu_copy", buffer[0], 528, placeholder_d_global[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_2[0], 336, placeholder_d_global_1[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder_5[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 10, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, placeholder_d_global[0], 416, T.int8(-1), T.int8(-1), 12, placeholder_d_global[416], 112, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder_5[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 6, 16, 0, 16, ethosu_write_1[10], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, placeholder_d_global_1[0], 272, T.int8(-1), T.int8(-1), 12, placeholder_d_global_1[272], 64, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) __tvm_meta__ = None # fmt: on diff --git a/tests/python/contrib/test_ethosu/test_scheduler.py b/tests/python/contrib/test_ethosu/test_scheduler.py index 4baea26e591e..ba050de2b473 100644 --- a/tests/python/contrib/test_ethosu/test_scheduler.py +++ b/tests/python/contrib/test_ethosu/test_scheduler.py @@ -182,24 +182,16 @@ class DiamondGraphTir: @T.prim_func def main(placeholder: T.Buffer[(301056,), "int8"], ethosu_write: T.Buffer[(75264,), "int8"]) -> None: T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - T.preflattened_buffer(placeholder, [1, 56, 56, 96], dtype='int8', data=placeholder.data) - T.preflattened_buffer(ethosu_write, [1, 56, 56, 24], dtype='int8', data=ethosu_write.data) - buffer1 = T.buffer_decl([2608], "uint8") - buffer2 = T.buffer_decl([240], "uint8") - buffer3 = T.buffer_decl([736], "uint8") - buffer4 = T.buffer_decl([240], "uint8") - p1 = T.allocate([2608], "uint8", "global", annotations={"disable_lower_builtin":True}) - p2 = T.allocate([240], "uint8", "global", annotations={"disable_lower_builtin":True}) - p3 = T.allocate([736], "uint8", "global", annotations={"disable_lower_builtin":True}) - p4 = T.allocate([240], "uint8", "global", annotations={"disable_lower_builtin":True}) + buffer1 = T.buffer_decl([2848], "uint8") + buffer3 = T.buffer_decl([976], "uint8") + p1 = T.allocate([2848], "uint8", "global", annotations={"disable_lower_builtin":True}) + p2 = T.allocate([976], "uint8", "global", annotations={"disable_lower_builtin":True}) p5 = T.allocate([75264], "int8", "global", annotations={"disable_lower_builtin":True}) p6 = T.allocate([75264], "int8", "global", annotations={"disable_lower_builtin":True}) - T.evaluate(T.call_extern("ethosu_copy", buffer1[0], 2608, p1[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 240, p2[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", buffer3[0], 736, p3[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", buffer4[0], 240, p4[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 56, 56, 96, 56, 0, 56, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 5376, 96, 1, "int8", 56, 56, 24, 56, 0, 56, p5[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 1344, 24, 1, 1, 1, 1, 1, 1, 1, p1[0], 2608, T.int8(-1), T.int8(-1), 12, p2[0], 240, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 56, 56, 24, 56, 0, 56, p5[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 1344, 24, 1, "int8", 56, 56, 24, 56, 0, 56, p6[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 1344, 24, 1, 1, 1, 1, 1, 1, 1, p3[0], 736, T.int8(-1), T.int8(-1), 12, p4[0], 240, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer1[0], 2848, p1[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer3[0], 976, p2[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 56, 56, 96, 56, 0, 56, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 5376, 96, 1, "int8", 56, 56, 24, 56, 0, 56, p5[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 1344, 24, 1, 1, 1, 1, 1, 1, 1, p1[0], 2608, T.int8(-1), T.int8(-1), 12, p1[2608], 240, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 56, 56, 24, 56, 0, 56, p5[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 1344, 24, 1, "int8", 56, 56, 24, 56, 0, 56, p6[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 1344, 24, 1, 1, 1, 1, 1, 1, 1, p2[0], 736, T.int8(-1), T.int8(-1), 12, p2[736], 240, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_binary_elementwise", "int8", 56, 56, 24, 56, 0, 56, p5[0], 0, 0, 0,T.float32(1), 0, "NHWC", 1344, 24, 1, "int8", 56, 56, 24, 56, 0, 56, p6[0], 0, 0, 0, T.float32(1), 0, "NHWC", 1344, 24, 1, "int8", 56, 56, 24, 56, 0, 56, ethosu_write[0], 0, 0, 0, T.float32(1), 0, "NHWC", 1344, 24, 1, "ADD", 0, "NONE", 0, 0, "TFL", 0, 0, 0, dtype="handle")) __tvm_meta__ = None # fmt: on From 2b8f95a61770fe346dac9644dbf11fd9ea073f64 Mon Sep 17 00:00:00 2001 From: Nicola Lancellotti Date: Thu, 7 Jul 2022 13:22:36 +0000 Subject: [PATCH 2/4] Fix errors and warnings Change-Id: I29f68f83a73fa00ca34ed0ab2321c53c6b761137 --- src/tir/contrib/ethosu/passes.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/tir/contrib/ethosu/passes.cc b/src/tir/contrib/ethosu/passes.cc index d39a19c480c5..93c1a11ea430 100644 --- a/src/tir/contrib/ethosu/passes.cc +++ b/src/tir/contrib/ethosu/passes.cc @@ -357,7 +357,7 @@ class MergeConstantsMutator : public StmtExprMutator { } } } - return seq_stmt; + return std::move(seq_stmt); } Stmt rewrite_prim_func_body(Stmt body) { @@ -465,7 +465,7 @@ class MergeConstantsMutator : public StmtExprMutator { int64_t get_stmt_cycle_counts(const Stmt& stmt) { auto attr{stmt.as()}; if (attr && attr->attr_key == "pragma_compute_cycles_hint") { - int64_t cycle_count = Downcast(attr->value); + int64_t cycle_count{Downcast(attr->value)->value}; return cycle_count; } return 0; @@ -593,7 +593,7 @@ class MergeConstantsMutator : public StmtExprMutator { PrimExpr value = cycle_counts.value_or(attr->value); return AttrStmt{attr->node, attr->attr_key, value, new_eval, attr->span}; } else { - return new_eval; + return std::move(new_eval); } } From 43bcaad36b588b04e6267cb53db0e718b8376f6a Mon Sep 17 00:00:00 2001 From: Nicola Lancellotti Date: Fri, 8 Jul 2022 11:01:24 +0000 Subject: [PATCH 3/4] Address comments Change-Id: Iad59107d5abdec6b079c6fd4ab48c6bffbb5e0bb --- .../backend/contrib/ethosu/tir/compiler.py | 2 + .../backend/contrib/ethosu/tir/passes.py | 8 +- src/tir/contrib/ethosu/passes.cc | 584 +++++++++--------- 3 files changed, 308 insertions(+), 286 deletions(-) diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py b/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py index b4896bd85e44..85c6df4c7d0c 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py @@ -90,6 +90,8 @@ def lower_ethosu(sch, args, const_dict, name="main"): mod = tvm.tir.transform.RemoveNoOp()(mod) mod, const_dict = ethosu_passes.EncodeConstants(const_dict)(mod) mod = ethosu_passes.HoistAllocates()(mod) + # MergeConstant pass currently does not support striped schedules. + # It requires further investigation. if not util.is_striping_enabled(): mod, const_dict = ethosu_passes.MergeConstants(const_dict)(mod) mod = ethosu_passes.CopyComputeReordering()(mod) diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py index 0efa23d36d7d..c0b017e703ce 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py @@ -947,7 +947,7 @@ def MergeConstants(const_dict): Constants written to a buffer with local scope are not merged. """ - def mergeConstantsPass(mod): + def _merge_constants(mod): nonlocal const_dict try: mod["main"] @@ -960,10 +960,10 @@ def mergeConstantsPass(mod): new_const_dict = {} for param in const_dict.keys(): new_const_dict[tvm.tir.IntImm("int64", param)] = tvm.nd.array(const_dict[param]) - mod["main"] = mod["main"].with_attr("ethos-u.const-dict", new_const_dict) + mod["main"] = mod["main"].with_attr("ethos-u.const_dict", new_const_dict) mod = _ffi_api.MergeConstants()(mod) - const_dict = mod["main"].attrs["ethos-u.const-dict"] + const_dict = mod["main"].attrs["ethos-u.const_dict"] mod = _ffi_api.RemoveConstDictAttribute()(mod) new_const_dict = {} @@ -972,4 +972,4 @@ def mergeConstantsPass(mod): return mod, new_const_dict - return mergeConstantsPass + return _merge_constants diff --git a/src/tir/contrib/ethosu/passes.cc b/src/tir/contrib/ethosu/passes.cc index 93c1a11ea430..8ec3185b7d88 100644 --- a/src/tir/contrib/ethosu/passes.cc +++ b/src/tir/contrib/ethosu/passes.cc @@ -29,6 +29,8 @@ #include #include +#include +#include namespace tvm { @@ -43,6 +45,62 @@ namespace tir { namespace contrib { namespace ethosu { +namespace { + +/*! Returns the arguments of the given statement */ +Array GetStmtArgs(const Stmt& stmt) { + auto attr{stmt.as()}; + Stmt eval_stmt{attr ? attr->body : stmt}; + auto eval{eval_stmt.as()}; + ICHECK(eval) << "Expected statement to be an evaluate node, but was " << eval_stmt->GetTypeKey(); + auto call{eval->value.as()}; + ICHECK(call) << "Expected expression to be a call node, but was " << eval->value->GetTypeKey(); + return call->args; +} + +enum class StmtType { global_copy, local_copy, compute }; + +/*! Returns the type of the given statement */ +StmtType GetStmtType(const Stmt& stmt) { + Array args{GetStmtArgs(stmt)}; + if (args[0].as()->value == "ethosu_copy") { + if (args[3].as()->buffer.scope() == "global") { + return StmtType::global_copy; + } else { + return StmtType::local_copy; + } + } + return StmtType::compute; +} +/*! Returns the buffer read my the given copy statement */ +Buffer GetCopyReadBuffer(const Stmt& stmt) { + Array args{GetStmtArgs(stmt)}; + return args[1].as()->buffer; +} + +/*! Returns the buffer written my the given copy statement */ +Buffer GetCopyWriteBuffer(const Stmt& stmt) { + Array args{GetStmtArgs(stmt)}; + return args[3].as()->buffer; +} + +/*! Returns the length of the given copy statement */ +int64_t GetCopyLength(const Stmt& stmt) { + Array args{GetStmtArgs(stmt)}; + return args[2].as()->value; +} + +/*! Returns the cycles of the given statement */ +int64_t GetStmtCycles(const Stmt& stmt) { + auto attr{stmt.as()}; + if (attr && attr->attr_key == "pragma_compute_cycles_hint") { + int64_t cycles{Downcast(attr->value)->value}; + return cycles; + } + return 0; +} +} // namespace + /*! * \brief This mutator moves allocates to the top of the body of the main * function. @@ -155,9 +213,9 @@ class CopyComputeReorderingMutator : public StmtExprMutator { // Each copy statement to a buffer with global scope is moved up // at most `_max_copy_movements` times. for (size_t index = 0; index < new_seq.size(); ++index) { - if (stmt_is_global_copy(new_seq[index])) { + if (GetStmtType(new_seq[index]) == StmtType::global_copy) { int lower = std::max(0, static_cast(index) - _max_copy_movements); - for (int i = index; i > lower && !stmt_is_copy(new_seq[i - 1]); --i) { + for (int i = index; i > lower && (GetStmtType(new_seq[i - 1]) == StmtType::compute); --i) { std::swap(new_seq[i - 1], new_seq[i]); } } @@ -168,32 +226,6 @@ class CopyComputeReorderingMutator : public StmtExprMutator { return Stmt{seq_stmt_node}; } - tvm::runtime::Array get_stmt_args(const Stmt& stmt) { - Stmt eval_stmt = stmt; - if (const auto* attr_stmt = eval_stmt.as()) { - eval_stmt = attr_stmt->body; - } - - auto eval_node{eval_stmt.as()}; - ICHECK(eval_node) << "Expected statement to be an evaluate node, but was " - << eval_stmt->GetTypeKey(); - auto call_node{eval_node->value.as()}; - ICHECK(call_node) << "Expected expression to be a call node, but was " - << eval_node->value->GetTypeKey(); - return call_node->args; - } - - bool stmt_is_copy(const Stmt& stmt) { - auto args{get_stmt_args(stmt)}; - return args[0].as()->value == "ethosu_copy"; - } - - bool stmt_is_global_copy(const Stmt& stmt) { - auto args{get_stmt_args(stmt)}; - return args[0].as()->value == "ethosu_copy" && - args[3].as()->buffer.scope() == "global"; - } - /*! The maximum number of movements allowed for a copy. */ int _max_copy_movements; }; @@ -225,130 +257,94 @@ TVM_REGISTER_GLOBAL("tir.contrib.ethos-u.CopyComputeReordering") .set_body_typed(CopyComputeReordering); /*! - * \brief This pass looks for the constants used by each compute operator - * and merges them into a single buffer. - * Constants written to a buffer with local scope are not merged. + * \brief This mutator removes all allocates. */ -class MergeConstantsMutator : public StmtExprMutator { +class RemoveAllocatesMutator : public StmtExprMutator { public: - MergeConstantsMutator() {} - - PrimFunc operator()(PrimFunc main_func, const Map& const_dict) { - // Analyze - Stmt new_body{this->VisitStmt(main_func->body)}; - - // Rewrite - analyze = false; - new_body = rewrite_prim_func_body(new_body); - std::set params_to_delete{}; - auto new_buffer_map{make_new_buffer_map(main_func->buffer_map, ¶ms_to_delete)}; - auto new_params{make_new_params(main_func->params, params_to_delete)}; - - // Make the new const dict - auto args_to_merge{get_args_to_merge(main_func->buffer_map, main_func->params)}; - auto buffers_to_merge{ - get_args_to_merge_without_args_not_in_const_dict(args_to_merge, const_dict)}; - auto new_const_dict{make_new_const_dict(buffers_to_merge, const_dict)}; - - // Make the new prim func + PrimFunc operator()(PrimFunc main_func) { auto prim_func_node{main_func.CopyOnWrite()}; - prim_func_node->body = std::move(new_body); - prim_func_node->buffer_map = std::move(new_buffer_map); - prim_func_node->params = std::move(new_params); - prim_func_node->preflattened_buffer_map = {}; - PrimFunc f{GetRef(prim_func_node)}; - - // Add the new const dict as an attribute - f = WithAttr(std::move(f), "ethos-u.const-dict", new_const_dict); - - return f; + prim_func_node->body = this->VisitStmt(main_func->body); + return GetRef(prim_func_node); } private: - /*! Indicates whether the pass is analyzing or rewriting */ - bool analyze = true; - - /*! A stack to store allocates as they are visited. */ - std::vector allocates{}; - - /*! A list that contains in the i-th position the write buffer of the i-th statement - * if that statement is a copy to a buffer with global scope */ - std::vector> copy_write_buffers{}; + Stmt VisitStmt_(const AllocateNode* op) override { return VisitStmt(op->body); } +}; - /*! Maps a copy's write buffer to an index representing the - * new buffer and an offset in that buffer */ - std::map> - old_to_new_write_buffer{}; +/*! + * \brief This extractor collects information used by the MergeConstantsMutator + */ +class MergeConstantsInfoExtractor : public StmtExprVisitor { + public: + class Info { + public: + /*! A stack to store allocates as they are visited. */ + std::vector allocates{}; - /*! Maps an index representing a new buffer to the length of that buffer */ - std::map new_buffers_length{}; + /*! A list that contains in the i-th position the write buffer of the i-th statement + * if that statement is a copy to a buffer with global scope */ + std::vector> copy_write_buffers{}; - /*! Maps an index representing a new buffer to the new buffer */ - std::map new_buffers{}; + /*! Maps a copy's write buffer to an index representing the + * new buffer and an offset in that buffer */ + std::unordered_map> + old_to_new_write_buffer{}; - /*! Maps an index representing a new buffer to the cycle_counts needed to copy that buffer */ - std::map cycle_counts{}; + /*! Maps an index representing a new buffer to the length of that buffer */ + std::unordered_map new_buffers_length{}; - /*! Maps a copy's read buffer to the new copy's read buffer */ - std::map old_to_new_read_buffers{}; - - /*! Maps an index representing a new buffer to the list of buffers to be merged in the new buffer - */ - std::map> buffers_to_merge{}; + /*! Maps an index representing a new buffer to the cycless needed to copy that buffer */ + std::unordered_map cycless{}; + }; - /*! A set of buffers to delete */ - std::set buffers_to_delete{}; + Info operator()(PrimFunc main_func) { + this->VisitStmt(main_func->body); + return std::move(_info); + } - // Visit + private: + /*! The information collected by this extractor */ + Info _info{}; - Stmt VisitStmt_(const AllocateNode* op) override { - if (analyze) { - allocates.push_back(GetRef(op)); - return VisitStmt(op->body); - } else { - auto allocate{CopyOnWrite(op)}; - allocate->body = this->VisitStmt(op->body); - return Stmt(allocate); - } + void VisitStmt_(const AllocateNode* op) override { + _info.allocates.push_back(GetRef(op)); + VisitStmt(op->body); } - Stmt VisitStmt_(const SeqStmtNode* op) override { + void VisitStmt_(const SeqStmtNode* op) override { if (op->size() <= 1) { - return StmtExprMutator::VisitStmt_(op); + StmtExprVisitor::VisitStmt_(op); + return; } - return analyze ? analyze_seq_stmt(op) : rewrite_seq_stmt(op); - } - Stmt analyze_seq_stmt(const SeqStmtNode* op) { auto seq_stmt{GetRef(op)}; - for (size_t i = 0; i < seq_stmt.size(); ++i) { Stmt stmt{seq_stmt[i]}; - - switch (get_stmt_type(stmt)) { + switch (GetStmtType(stmt)) { case StmtType::global_copy: { - Buffer write_buffer{get_copy_write_buffer(stmt)}; - copy_write_buffers.push_back(write_buffer); - old_to_new_write_buffer[write_buffer] = std::make_pair(-1, -1); + Buffer write_buffer{GetCopyWriteBuffer(stmt)}; + _info.copy_write_buffers.push_back(write_buffer); + _info.old_to_new_write_buffer[write_buffer.as()] = std::make_pair(-1, -1); break; } case StmtType::local_copy: { - copy_write_buffers.push_back(Optional{}); + _info.copy_write_buffers.push_back(Optional{}); break; } case StmtType::compute: { - copy_write_buffers.push_back(Optional{}); - auto buffers{get_copied_buffers_used_by_stmt(stmt)}; + _info.copy_write_buffers.push_back(Optional{}); + std::vector buffers{GetCopiedBuffersUsedByStmt(stmt)}; if (buffers.empty()) { continue; } - new_buffers_length[i] = 0; - for (auto buffer : buffers) { + _info.new_buffers_length[i] = 0; + for (Buffer buffer : buffers) { for (size_t j{i - 1}; j >= 0; --j) { - if (copy_write_buffers[j] == buffer) { - old_to_new_write_buffer[buffer] = std::make_pair(i, new_buffers_length[i]); - new_buffers_length[i] += get_copy_length(seq_stmt[j]); - cycle_counts[i] += get_stmt_cycle_counts(seq_stmt[j]); + if (_info.copy_write_buffers[j] == buffer) { + _info.old_to_new_write_buffer[buffer.as()] = + std::make_pair(i, _info.new_buffers_length[i]); + _info.new_buffers_length[i] += GetCopyLength(seq_stmt[j]); + _info.cycless[i] += GetStmtCycles(seq_stmt[j]); break; } } @@ -357,37 +353,109 @@ class MergeConstantsMutator : public StmtExprMutator { } } } - return std::move(seq_stmt); } - Stmt rewrite_prim_func_body(Stmt body) { - std::map var_to_allocate{}; + /*! Get all buffers written by copies and used by a given statement */ + std::vector GetCopiedBuffersUsedByStmt(const Stmt& stmt) { + std::vector buffers{}; + for (PrimExpr arg : GetStmtArgs(stmt)) { + if (auto buffer_load = arg.as()) { + Buffer buffer{buffer_load->buffer}; + // Check if the buffer has already been added + if (std::find(buffers.begin(), buffers.end(), buffer) == buffers.end()) { + // Check if the buffer is copied + if (_info.old_to_new_write_buffer.count(buffer.as())) { + buffers.push_back(buffer); + } + } + } + } + return buffers; + } +}; + +/*! + * \brief This mutator looks for the constants used by each compute operator + * and merges them into a single buffer. + * Constants written to a buffer with local scope are not merged. + */ +class MergeConstantsMutator : public StmtExprMutator { + public: + MergeConstantsMutator(MergeConstantsInfoExtractor::Info info) : _info{std::move(info)} {} + + PrimFunc operator()(PrimFunc main_func, const Map& const_dict) { + // Rewrite + Stmt new_body = RewritePrimFuncBody(main_func->body); + std::unordered_set params_to_delete{}; + Map new_buffer_map{MakeNewBufferMap(main_func->buffer_map, ¶ms_to_delete)}; + Array new_params{MakeNewParams(main_func->params, params_to_delete)}; + + // Make the new const dict + Array> args_to_merge{GetArgsToMerge(main_func->buffer_map, main_func->params)}; + Array> buffers_to_merge{ + GetArgsToMergeWithoutArgsNotInConstDict(args_to_merge, const_dict)}; + Map new_const_dict{MakeNewConstDict(buffers_to_merge, const_dict)}; + + // Make the new prim func + auto prim_func_node{main_func.CopyOnWrite()}; + prim_func_node->body = std::move(new_body); + prim_func_node->buffer_map = std::move(new_buffer_map); + prim_func_node->params = std::move(new_params); + prim_func_node->preflattened_buffer_map = {}; + PrimFunc f{GetRef(prim_func_node)}; + + // Add the new const dict as an attribute + f = WithAttr(std::move(f), "ethos-u.const_dict", new_const_dict); + + return f; + } + + private: + /*! The information collected by the MergeConstantsInfoExtractor */ + MergeConstantsInfoExtractor::Info _info; + + /*! Maps an index representing a new buffer to the new buffer */ + std::unordered_map new_buffers{}; + + /*! Maps a copy's read buffer to the new copy's read buffer */ + std::unordered_map old_to_new_read_buffers{}; + + /*! Maps an index representing a new buffer to the list of buffers to be merged in the new buffer + */ + std::unordered_map> buffers_to_merge{}; + + /*! A set of buffers to delete */ + std::unordered_set buffers_to_delete{}; + + Stmt RewritePrimFuncBody(Stmt body) { + std::unordered_map var_to_allocate{}; // Rewrite old allocates - std::set buffer_vars{get_vars_for_written_copy_buffers()}; - for (auto it{allocates.rbegin()}; it != allocates.rend(); ++it) { + std::unordered_set buffer_vars{GetVarsForWrittenCopyBuffers()}; + for (auto it{_info.allocates.rbegin()}; it != _info.allocates.rend(); ++it) { Allocate alloc{*it}; var_to_allocate[alloc->buffer_var.get()] = alloc; - if (buffer_vars.count(alloc->buffer_var) == 0) { + if (buffer_vars.count(alloc->buffer_var.as()) == 0) { body = Allocate(alloc->buffer_var, alloc->dtype, alloc->extents, alloc->condition, body, alloc->annotations, alloc->span); } } // Rewrite new allocates - for (auto it{copy_write_buffers.rbegin()}; it != copy_write_buffers.rend(); ++it) { - if (auto buffer_opt = *it) { + for (auto it{_info.copy_write_buffers.rbegin()}; it != _info.copy_write_buffers.rend(); ++it) { + if (Optional buffer_opt = *it) { Buffer old_write_buffer{buffer_opt.value()}; - int new_buffer_index{old_to_new_write_buffer[old_write_buffer].first}; + int new_buffer_index{ + _info.old_to_new_write_buffer[old_write_buffer.as()].first}; // Check if the allocate has already been created if (new_buffers.count(new_buffer_index) == 0) { BufferNode* new_buffer{old_write_buffer.CopyOnWrite()}; - new_buffer->shape = {new_buffers_length[new_buffer_index]}; + new_buffer->shape = {_info.new_buffers_length[new_buffer_index]}; new_buffers[new_buffer_index] = GetRef(new_buffer); - auto old_allocate{var_to_allocate[old_write_buffer->data.get()]}; + Allocate old_allocate{var_to_allocate[old_write_buffer->data.get()]}; body = Allocate(new_buffer->data, new_buffer->dtype, new_buffer->shape, tir::const_true(), body, old_allocate->annotations, old_allocate->span); } @@ -398,25 +466,35 @@ class MergeConstantsMutator : public StmtExprMutator { return this->VisitStmt(body); } - Stmt rewrite_seq_stmt(const SeqStmtNode* op) { - Array new_seq{}; + Stmt VisitStmt_(const AllocateNode* op) override { + auto allocate{CopyOnWrite(op)}; + allocate->body = this->VisitStmt(op->body); + return Stmt(allocate); + } - auto seq_stmt{GetRef(op)}; + Stmt VisitStmt_(const SeqStmtNode* op) override { + if (op->size() <= 1) { + return StmtExprMutator::VisitStmt_(op); + } + + Array new_seq{}; + SeqStmt seq_stmt{GetRef(op)}; for (size_t i{0}; i < seq_stmt.size(); ++i) { Stmt stmt{seq_stmt[i]}; - switch (get_stmt_type(stmt)) { + switch (GetStmtType(stmt)) { case StmtType::global_copy: { - Buffer old_write_buffer{copy_write_buffers[i].value()}; - auto pair{old_to_new_write_buffer[old_write_buffer]}; - auto new_buffer_index{pair.first}; - auto new_buffer_offset{pair.second}; - update_buffers_to_merge_and_delete(stmt, new_buffer_index, new_buffer_offset); - - if (!is_copy_to_be_deleted(new_buffer_offset)) { - auto cycle_counts{get_merged_cycle_counts(new_buffer_index)}; - new_seq.push_back(make_new_stmt( - stmt, make_new_copy_args(stmt, old_write_buffer, new_buffer_index), cycle_counts)); + Buffer old_write_buffer{_info.copy_write_buffers[i].value()}; + std::pair pair{ + _info.old_to_new_write_buffer[old_write_buffer.as()]}; + int new_buffer_index{pair.first}; + int new_buffer_offset{pair.second}; + UpdateBuffersToMergeAndDelete(stmt, new_buffer_index, new_buffer_offset); + + if (!IsCopyToBeDeleted(new_buffer_offset)) { + Optional cycless{GetMergedCycles(new_buffer_index)}; + new_seq.push_back(MakeNewStmt( + stmt, MakeNewCopyArgs(stmt, old_write_buffer, new_buffer_index), cycless)); } break; } @@ -425,7 +503,7 @@ class MergeConstantsMutator : public StmtExprMutator { break; } case StmtType::compute: { - new_seq.push_back(make_new_stmt(stmt, make_new_compute_args(stmt))); + new_seq.push_back(MakeNewStmt(stmt, MakeNewComputeArgs(stmt))); break; } } @@ -433,101 +511,40 @@ class MergeConstantsMutator : public StmtExprMutator { return SeqStmt(new_seq, op->span); } - enum class StmtType { global_copy, local_copy, compute }; - - StmtType get_stmt_type(const Stmt& stmt) { - auto args{get_stmt_args(stmt)}; - if (args[0].as()->value == "ethosu_copy") { - if (args[3].as()->buffer.scope() == "global") { - return StmtType::global_copy; - } else { - return StmtType::local_copy; - } - } - return StmtType::compute; - } - - Buffer get_copy_read_buffer(const Stmt& stmt) { - auto args{get_stmt_args(stmt)}; - return args[1].as()->buffer; - } - - Buffer get_copy_write_buffer(const Stmt& stmt) { - auto args{get_stmt_args(stmt)}; - return args[3].as()->buffer; - } - - int64_t get_copy_length(const Stmt& stmt) { - auto args{get_stmt_args(stmt)}; - return args[2].as()->value; - } - - int64_t get_stmt_cycle_counts(const Stmt& stmt) { - auto attr{stmt.as()}; - if (attr && attr->attr_key == "pragma_compute_cycles_hint") { - int64_t cycle_count{Downcast(attr->value)->value}; - return cycle_count; - } - return 0; - } - - std::vector get_copied_buffers_used_by_stmt(const Stmt& stmt) { - std::vector buffers{}; - for (auto arg : get_stmt_args(stmt)) { - if (auto buffer_load = arg.as()) { - auto buffer{buffer_load->buffer}; - // Check if the buffer has already been added - if (std::find(buffers.begin(), buffers.end(), buffer) == buffers.end()) { - // Check if the buffer is copied - if (old_to_new_write_buffer.count(buffer)) { - buffers.push_back(buffer); - } - } - } - } - return buffers; - } - - std::set get_vars_for_written_copy_buffers() { - std::set buffer_vars{}; - std::transform(old_to_new_write_buffer.begin(), old_to_new_write_buffer.end(), + /*! Returns the variables of the buffers written by copies */ + std::unordered_set GetVarsForWrittenCopyBuffers() { + std::unordered_set buffer_vars{}; + std::transform(_info.old_to_new_write_buffer.begin(), _info.old_to_new_write_buffer.end(), std::inserter(buffer_vars, buffer_vars.begin()), - [](auto pair) -> Var { return pair.first->data; }); + [](std::pair> pair) -> const VarNode* { + return pair.first->data.as(); + }); return buffer_vars; } - tvm::runtime::Array get_stmt_args(const Stmt& stmt) { - auto attr{stmt.as()}; - Stmt eval_stmt{attr ? attr->body : stmt}; - auto eval{eval_stmt.as()}; - ICHECK(eval) << "Expected statement to be an evaluate node, but was " - << eval_stmt->GetTypeKey(); - auto call{eval->value.as()}; - ICHECK(call) << "Expected expression to be a call node, but was " << eval->value->GetTypeKey(); - return call->args; - } - - Optional get_merged_cycle_counts(int new_buffer_index) { - auto it = cycle_counts.find(new_buffer_index); - if (it != cycle_counts.end()) { + /*! Returns the cycles of the new buffer at the given index */ + Optional GetMergedCycles(int new_buffer_index) { + auto it = _info.cycless.find(new_buffer_index); + if (it != _info.cycless.end()) { return Integer(it->second); } return Optional{}; } - bool is_copy_to_be_deleted(int new_buffer_offset) { return new_buffer_offset > 0; } + /*! Returns true if a copy must be deleted, false otherwise */ + bool IsCopyToBeDeleted(int new_buffer_offset) { return new_buffer_offset > 0; } - Array make_new_copy_args(const Stmt& stmt, const Buffer& old_write_buffer, - int new_buffer_index) { - Array args{get_stmt_args(stmt)}; - auto new_length{new_buffers_length[new_buffer_index]}; + Array MakeNewCopyArgs(const Stmt& stmt, const Buffer& old_write_buffer, + int new_buffer_index) { + Array args{GetStmtArgs(stmt)}; + int new_length{_info.new_buffers_length[new_buffer_index]}; Array new_args{}; for (size_t i = 0; i < args.size(); ++i) { switch (i) { case 1: /* read_address */ { auto buffer_load = args[1].as(); - auto buffer{buffer_load->buffer}; + Buffer buffer{buffer_load->buffer}; Buffer new_buffer{buffer->data, buffer->dtype, {new_length}, @@ -539,7 +556,7 @@ class MergeConstantsMutator : public StmtExprMutator { buffer->buffer_type, buffer->axis_separators, buffer->span}; - old_to_new_read_buffers[buffer] = new_buffer; + old_to_new_read_buffers[buffer.as()] = new_buffer; new_args.push_back(BufferLoad(new_buffer, buffer_load->indices, buffer_load->span)); break; } @@ -548,7 +565,7 @@ class MergeConstantsMutator : public StmtExprMutator { break; } case 3: /* write_address */ { - new_args.push_back(make_new_buffer_load(old_write_buffer, 0, true).value()); + new_args.push_back(MakeNewBufferLoad(old_write_buffer, 0, true).value()); break; } default: @@ -559,13 +576,13 @@ class MergeConstantsMutator : public StmtExprMutator { return new_args; } - Array make_new_compute_args(const Stmt& stmt) { - Array args{get_stmt_args(stmt)}; + Array MakeNewComputeArgs(const Stmt& stmt) { + Array args{GetStmtArgs(stmt)}; Array new_args{}; for (size_t i = 0; i < args.size(); ++i) { if (auto buffer_load = args[i].as()) { - auto new_buffer_load{ - make_new_buffer_load(buffer_load->buffer, buffer_load->indices[0], false) + BufferLoad new_buffer_load{ + MakeNewBufferLoad(buffer_load->buffer, buffer_load->indices[0], false) .value_or(GetRef(buffer_load))}; new_args.push_back(new_buffer_load); } else { @@ -575,8 +592,8 @@ class MergeConstantsMutator : public StmtExprMutator { return new_args; } - Stmt make_new_stmt(const Stmt& stmt, const Array& new_args, - Optional cycle_counts = Optional{}) { + Stmt MakeNewStmt(const Stmt& stmt, const Array& new_args, + Optional cycless = Optional{}) { auto attr{stmt.as()}; Stmt eval_stmt{attr ? attr->body : stmt}; auto eval{eval_stmt.as()}; @@ -590,36 +607,36 @@ class MergeConstantsMutator : public StmtExprMutator { if (attr) { ICHECK(attr->attr_key == "pragma_compute_cycles_hint"); - PrimExpr value = cycle_counts.value_or(attr->value); + PrimExpr value = cycless.value_or(attr->value); return AttrStmt{attr->node, attr->attr_key, value, new_eval, attr->span}; } else { return std::move(new_eval); } } - Optional make_new_buffer_load(const Buffer& write_buffer, const PrimExpr& old_index, - bool only_old_index) { - auto it = old_to_new_write_buffer.find(write_buffer); - if (it != old_to_new_write_buffer.end()) { - auto pair{it->second}; - auto new_buffer_index{pair.first}; - auto new_index{only_old_index ? old_index : (pair.second + old_index)}; + Optional MakeNewBufferLoad(const Buffer& write_buffer, const PrimExpr& old_index, + bool only_old_index) { + auto it = _info.old_to_new_write_buffer.find(write_buffer.as()); + if (it != _info.old_to_new_write_buffer.end()) { + std::pair pair{it->second}; + int new_buffer_index{pair.first}; + PrimExpr new_index{only_old_index ? old_index : (pair.second + old_index)}; return BufferLoad{new_buffers[new_buffer_index], {new_index}}; } return Optional{}; } - Map make_new_buffer_map(const Map& buffer_map, - std::set* params_to_delete) { + Map MakeNewBufferMap(const Map& buffer_map, + std::unordered_set* params_to_delete) { Map new_buffer_map{}; - for (auto pair : buffer_map) { + for (std::pair pair : buffer_map) { Var var{pair.first}; Buffer buffer{pair.second}; - if (buffers_to_delete.count(buffer) == 1) { - params_to_delete->insert(var); - } else if (old_to_new_read_buffers.count(buffer) == 1) { - new_buffer_map.Set(var, old_to_new_read_buffers[buffer]); + if (buffers_to_delete.count(buffer.as()) == 1) { + params_to_delete->insert(var.as()); + } else if (old_to_new_read_buffers.count(buffer.as()) == 1) { + new_buffer_map.Set(var, old_to_new_read_buffers[buffer.as()]); } else { new_buffer_map.Set(var, buffer); } @@ -627,21 +644,21 @@ class MergeConstantsMutator : public StmtExprMutator { return new_buffer_map; } - Array make_new_params(const Array& params, - const std::set& params_to_delete) { + Array MakeNewParams(const Array& params, + const std::unordered_set& params_to_delete) { std::vector new_params{}; - for (auto var : params) { - if (params_to_delete.count(var) == 0) { + for (Var var : params) { + if (params_to_delete.count(var.as()) == 0) { new_params.push_back(var); } } return new_params; } - void update_buffers_to_merge_and_delete(const Stmt& stmt, int new_buffer_index, - int new_buffer_offset) { - Array args{get_stmt_args(stmt)}; - Buffer read_buffer{get_copy_read_buffer(stmt)}; + void UpdateBuffersToMergeAndDelete(const Stmt& stmt, int new_buffer_index, + int new_buffer_offset) { + Array args{GetStmtArgs(stmt)}; + Buffer read_buffer{GetCopyReadBuffer(stmt)}; if (buffers_to_merge.count(new_buffer_index) == 0) { buffers_to_merge[new_buffer_index] = std::vector{read_buffer}; @@ -650,30 +667,30 @@ class MergeConstantsMutator : public StmtExprMutator { } if (new_buffer_offset > 0) { - buffers_to_delete.insert(read_buffer); + buffers_to_delete.insert(read_buffer.as()); } } /*! Returns an array whose elements are the indices of the function arguments to be merged. * Example: if a function has three arguments and the second and the third ones must * be merged then the array is: [[0], [1, 2], [3]] */ - Array> get_args_to_merge(const Map& buffer_map, - const Array& params) { - std::map buffer_to_var{}; - for (auto var_buffer : buffer_map) { - buffer_to_var[var_buffer.second] = var_buffer.first; + Array> GetArgsToMerge(const Map& buffer_map, + const Array& params) { + std::unordered_map buffer_to_var{}; + for (std::pair var_buffer : buffer_map) { + buffer_to_var[var_buffer.second.as()] = var_buffer.first; } - std::map var_to_index{}; + std::unordered_map var_to_index{}; for (int i = 0; i < static_cast(params.size()); ++i) { - var_to_index[params[i]] = i; + var_to_index[params[i].as()] = i; } std::vector> vector{}; - for (auto index_vector : buffers_to_merge) { + for (std::pair> index_vector : buffers_to_merge) { std::vector indices{}; - for (auto buffer : index_vector.second) { - auto var{buffer_to_var[buffer]}; + for (Buffer buffer : index_vector.second) { + const VarNode* var{buffer_to_var[buffer.as()].as()}; IntImm index{DataType::Int(64), var_to_index[var]}; var_to_index.erase(var); auto it = std::find_if(indices.begin(), indices.end(), @@ -685,7 +702,7 @@ class MergeConstantsMutator : public StmtExprMutator { vector.push_back(Array{indices}); } - for (auto var_index : var_to_index) { + for (std::pair var_index : var_to_index) { vector.push_back(Array{IntImm(DataType::Int(64), var_index.second)}); } std::sort(vector.begin(), vector.end(), @@ -693,10 +710,10 @@ class MergeConstantsMutator : public StmtExprMutator { return vector; } - Array> get_args_to_merge_without_args_not_in_const_dict( + Array> GetArgsToMergeWithoutArgsNotInConstDict( const Array>& args_to_merge, const Map& const_dict) { Array> new_args_to_merge{}; - for (auto args : args_to_merge) { + for (Array args : args_to_merge) { IntImm key{args[0]}; auto it = std::find_if(const_dict.begin(), const_dict.end(), [&](std::pair pair) { @@ -709,30 +726,30 @@ class MergeConstantsMutator : public StmtExprMutator { return new_args_to_merge; } - Map make_new_const_dict(const Array>& args_to_merge, - Map const_dict) { + Map MakeNewConstDict(const Array>& args_to_merge, + Map const_dict) { Map new_const_dict{}; if (args_to_merge.size() == 0) { return new_const_dict; } int64_t key = args_to_merge[0][0]->value; - for (auto args : args_to_merge) { + for (Array args : args_to_merge) { int64_t size = 0; - for (auto arg : args) { + for (IntImm arg : args) { auto it = std::find_if(const_dict.begin(), const_dict.end(), [&](auto pair) { return pair.first->value == arg->value; }); - auto arg_constant{(*it).second}; + runtime::NDArray arg_constant{(*it).second}; size += runtime::GetDataSize(*arg_constant.operator->()); } runtime::NDArray constant = runtime::NDArray::Empty({size}, DataType::UInt(8), {kDLCPU, 0}); size_t offset = 0; - for (auto arg : args) { + for (IntImm arg : args) { auto it = std::find_if(const_dict.begin(), const_dict.end(), [&](auto pair) { return pair.first->value == arg->value; }); - auto arg_constant{(*it).second}; + runtime::NDArray arg_constant{(*it).second}; size_t nbytes = runtime::GetDataSize(*arg_constant.operator->()); arg_constant.CopyToBytes(static_cast(constant->data) + offset, nbytes); offset += nbytes; @@ -755,10 +772,13 @@ tvm::transform::Pass MergeConstants() { ICHECK(mod->GetGlobalVars().size() == 1 && mod->ContainGlobalVar("main")) << "Expected a single primitive function called 'main'. Please run the " "MergeConstants pass in conjunction with the LowerToTIR() pass."; - auto const_dict{ - f->attrs.GetAttr("ethos-u.const-dict", Optional>{})}; - ICHECK(const_dict) << "Expected a ethos-u.const-dict attribute"; - return MergeConstantsMutator()(f, const_dict.value()); + Optional> const_dict{ + f->attrs.GetAttr("ethos-u.const_dict", Optional>{})}; + ICHECK(const_dict) << "Expected a ethos-u.const_dict attribute"; + + MergeConstantsInfoExtractor::Info info{MergeConstantsInfoExtractor()(f)}; + f = RemoveAllocatesMutator()(f); + return MergeConstantsMutator(info)(f, const_dict.value()); }; return tvm::tir::transform::CreatePrimFuncPass(pass_func, 0, "tir.contrib.ethos-u.MergeConstants", {}); @@ -767,7 +787,7 @@ tvm::transform::Pass MergeConstants() { TVM_REGISTER_GLOBAL("tir.contrib.ethos-u.MergeConstants").set_body_typed(MergeConstants); /*! - * \brief This pass removes the ethos-u.const-dict attribute + * \brief This pass removes the ethos-u.const_dict attribute * \return tvm::transform::Pass */ class RemoveConstDictAttributeMutator : public StmtExprMutator { @@ -775,7 +795,7 @@ class RemoveConstDictAttributeMutator : public StmtExprMutator { RemoveConstDictAttributeMutator() {} PrimFunc operator()(PrimFunc main_func) { - return WithoutAttr(std::move(main_func), "ethos-u.const-dict"); + return WithoutAttr(std::move(main_func), "ethos-u.const_dict"); } }; From 5ca38dcb9a4a74e7daa1661f67d5ed574fe6267e Mon Sep 17 00:00:00 2001 From: Nicola Lancellotti Date: Fri, 8 Jul 2022 18:39:39 +0000 Subject: [PATCH 4/4] Fix lint error Change-Id: Ie5caf506337de01e169d6f422e4682eefbd93241 --- src/tir/contrib/ethosu/passes.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tir/contrib/ethosu/passes.cc b/src/tir/contrib/ethosu/passes.cc index 8ec3185b7d88..fe7055a9e0d8 100644 --- a/src/tir/contrib/ethosu/passes.cc +++ b/src/tir/contrib/ethosu/passes.cc @@ -381,7 +381,7 @@ class MergeConstantsInfoExtractor : public StmtExprVisitor { */ class MergeConstantsMutator : public StmtExprMutator { public: - MergeConstantsMutator(MergeConstantsInfoExtractor::Info info) : _info{std::move(info)} {} + explicit MergeConstantsMutator(MergeConstantsInfoExtractor::Info info) : _info{std::move(info)} {} PrimFunc operator()(PrimFunc main_func, const Map& const_dict) { // Rewrite