diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py b/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py index 0fd82378c300..85c6df4c7d0c 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py @@ -90,6 +90,10 @@ def lower_ethosu(sch, args, const_dict, name="main"): mod = tvm.tir.transform.RemoveNoOp()(mod) mod, const_dict = ethosu_passes.EncodeConstants(const_dict)(mod) mod = ethosu_passes.HoistAllocates()(mod) + # MergeConstant pass currently does not support striped schedules. + # It requires further investigation. + if not util.is_striping_enabled(): + mod, const_dict = ethosu_passes.MergeConstants(const_dict)(mod) mod = ethosu_passes.CopyComputeReordering()(mod) # When striping is enabled and if storage_rewrite is not run diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py index 76726132e05d..c0b017e703ce 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py @@ -938,3 +938,38 @@ def CopyComputeReordering(max_copy_movements: Optional[int] = None) -> tvm.IRMod The new module with copy and compute nodes reordered. """ return _ffi_api.CopyComputeReordering(max_copy_movements) + + +def MergeConstants(const_dict): + """ + This pass looks for the constants used by each compute operator + and merges them into a single buffer. + Constants written to a buffer with local scope are not merged. + """ + + def _merge_constants(mod): + nonlocal const_dict + try: + mod["main"] + except: + raise tvm.TVMError( + "Expected a single primitive function called 'main'. " + "Please run the MergeConstants pass in conjunction with the LowerToTIR() pass." + ) + + new_const_dict = {} + for param in const_dict.keys(): + new_const_dict[tvm.tir.IntImm("int64", param)] = tvm.nd.array(const_dict[param]) + mod["main"] = mod["main"].with_attr("ethos-u.const_dict", new_const_dict) + + mod = _ffi_api.MergeConstants()(mod) + const_dict = mod["main"].attrs["ethos-u.const_dict"] + mod = _ffi_api.RemoveConstDictAttribute()(mod) + + new_const_dict = {} + for param in const_dict.keys(): + new_const_dict[int(param)] = const_dict[param].numpy() + + return mod, new_const_dict + + return _merge_constants diff --git a/src/tir/contrib/ethosu/passes.cc b/src/tir/contrib/ethosu/passes.cc index 09c359c55abb..fe7055a9e0d8 100644 --- a/src/tir/contrib/ethosu/passes.cc +++ b/src/tir/contrib/ethosu/passes.cc @@ -24,10 +24,13 @@ */ #include #include +#include #include #include #include +#include +#include namespace tvm { @@ -42,6 +45,62 @@ namespace tir { namespace contrib { namespace ethosu { +namespace { + +/*! Returns the arguments of the given statement */ +Array GetStmtArgs(const Stmt& stmt) { + auto attr{stmt.as()}; + Stmt eval_stmt{attr ? attr->body : stmt}; + auto eval{eval_stmt.as()}; + ICHECK(eval) << "Expected statement to be an evaluate node, but was " << eval_stmt->GetTypeKey(); + auto call{eval->value.as()}; + ICHECK(call) << "Expected expression to be a call node, but was " << eval->value->GetTypeKey(); + return call->args; +} + +enum class StmtType { global_copy, local_copy, compute }; + +/*! Returns the type of the given statement */ +StmtType GetStmtType(const Stmt& stmt) { + Array args{GetStmtArgs(stmt)}; + if (args[0].as()->value == "ethosu_copy") { + if (args[3].as()->buffer.scope() == "global") { + return StmtType::global_copy; + } else { + return StmtType::local_copy; + } + } + return StmtType::compute; +} +/*! Returns the buffer read my the given copy statement */ +Buffer GetCopyReadBuffer(const Stmt& stmt) { + Array args{GetStmtArgs(stmt)}; + return args[1].as()->buffer; +} + +/*! Returns the buffer written my the given copy statement */ +Buffer GetCopyWriteBuffer(const Stmt& stmt) { + Array args{GetStmtArgs(stmt)}; + return args[3].as()->buffer; +} + +/*! Returns the length of the given copy statement */ +int64_t GetCopyLength(const Stmt& stmt) { + Array args{GetStmtArgs(stmt)}; + return args[2].as()->value; +} + +/*! Returns the cycles of the given statement */ +int64_t GetStmtCycles(const Stmt& stmt) { + auto attr{stmt.as()}; + if (attr && attr->attr_key == "pragma_compute_cycles_hint") { + int64_t cycles{Downcast(attr->value)->value}; + return cycles; + } + return 0; +} +} // namespace + /*! * \brief This mutator moves allocates to the top of the body of the main * function. @@ -154,9 +213,9 @@ class CopyComputeReorderingMutator : public StmtExprMutator { // Each copy statement to a buffer with global scope is moved up // at most `_max_copy_movements` times. for (size_t index = 0; index < new_seq.size(); ++index) { - if (stmt_is_global_copy(new_seq[index])) { + if (GetStmtType(new_seq[index]) == StmtType::global_copy) { int lower = std::max(0, static_cast(index) - _max_copy_movements); - for (int i = index; i > lower && !stmt_is_copy(new_seq[i - 1]); --i) { + for (int i = index; i > lower && (GetStmtType(new_seq[i - 1]) == StmtType::compute); --i) { std::swap(new_seq[i - 1], new_seq[i]); } } @@ -167,32 +226,6 @@ class CopyComputeReorderingMutator : public StmtExprMutator { return Stmt{seq_stmt_node}; } - tvm::runtime::Array get_stmt_args(const Stmt& stmt) { - Stmt eval_stmt = stmt; - if (const auto* attr_stmt = eval_stmt.as()) { - eval_stmt = attr_stmt->body; - } - - auto eval_node{eval_stmt.as()}; - ICHECK(eval_node) << "Expected statement to be an evaluate node, but was " - << eval_stmt->GetTypeKey(); - auto call_node{eval_node->value.as()}; - ICHECK(call_node) << "Expected expression to be a call node, but was " - << eval_node->value->GetTypeKey(); - return call_node->args; - } - - bool stmt_is_copy(const Stmt& stmt) { - auto args{get_stmt_args(stmt)}; - return args[0].as()->value == "ethosu_copy"; - } - - bool stmt_is_global_copy(const Stmt& stmt) { - auto args{get_stmt_args(stmt)}; - return args[0].as()->value == "ethosu_copy" && - args[3].as()->buffer.scope() == "global"; - } - /*! The maximum number of movements allowed for a copy. */ int _max_copy_movements; }; @@ -223,6 +256,560 @@ tvm::transform::Pass CopyComputeReordering(Optional max_copy_movements) TVM_REGISTER_GLOBAL("tir.contrib.ethos-u.CopyComputeReordering") .set_body_typed(CopyComputeReordering); +/*! + * \brief This mutator removes all allocates. + */ +class RemoveAllocatesMutator : public StmtExprMutator { + public: + PrimFunc operator()(PrimFunc main_func) { + auto prim_func_node{main_func.CopyOnWrite()}; + prim_func_node->body = this->VisitStmt(main_func->body); + return GetRef(prim_func_node); + } + + private: + Stmt VisitStmt_(const AllocateNode* op) override { return VisitStmt(op->body); } +}; + +/*! + * \brief This extractor collects information used by the MergeConstantsMutator + */ +class MergeConstantsInfoExtractor : public StmtExprVisitor { + public: + class Info { + public: + /*! A stack to store allocates as they are visited. */ + std::vector allocates{}; + + /*! A list that contains in the i-th position the write buffer of the i-th statement + * if that statement is a copy to a buffer with global scope */ + std::vector> copy_write_buffers{}; + + /*! Maps a copy's write buffer to an index representing the + * new buffer and an offset in that buffer */ + std::unordered_map> + old_to_new_write_buffer{}; + + /*! Maps an index representing a new buffer to the length of that buffer */ + std::unordered_map new_buffers_length{}; + + /*! Maps an index representing a new buffer to the cycless needed to copy that buffer */ + std::unordered_map cycless{}; + }; + + Info operator()(PrimFunc main_func) { + this->VisitStmt(main_func->body); + return std::move(_info); + } + + private: + /*! The information collected by this extractor */ + Info _info{}; + + void VisitStmt_(const AllocateNode* op) override { + _info.allocates.push_back(GetRef(op)); + VisitStmt(op->body); + } + + void VisitStmt_(const SeqStmtNode* op) override { + if (op->size() <= 1) { + StmtExprVisitor::VisitStmt_(op); + return; + } + + auto seq_stmt{GetRef(op)}; + for (size_t i = 0; i < seq_stmt.size(); ++i) { + Stmt stmt{seq_stmt[i]}; + switch (GetStmtType(stmt)) { + case StmtType::global_copy: { + Buffer write_buffer{GetCopyWriteBuffer(stmt)}; + _info.copy_write_buffers.push_back(write_buffer); + _info.old_to_new_write_buffer[write_buffer.as()] = std::make_pair(-1, -1); + break; + } + case StmtType::local_copy: { + _info.copy_write_buffers.push_back(Optional{}); + break; + } + case StmtType::compute: { + _info.copy_write_buffers.push_back(Optional{}); + std::vector buffers{GetCopiedBuffersUsedByStmt(stmt)}; + if (buffers.empty()) { + continue; + } + _info.new_buffers_length[i] = 0; + for (Buffer buffer : buffers) { + for (size_t j{i - 1}; j >= 0; --j) { + if (_info.copy_write_buffers[j] == buffer) { + _info.old_to_new_write_buffer[buffer.as()] = + std::make_pair(i, _info.new_buffers_length[i]); + _info.new_buffers_length[i] += GetCopyLength(seq_stmt[j]); + _info.cycless[i] += GetStmtCycles(seq_stmt[j]); + break; + } + } + } + break; + } + } + } + } + + /*! Get all buffers written by copies and used by a given statement */ + std::vector GetCopiedBuffersUsedByStmt(const Stmt& stmt) { + std::vector buffers{}; + for (PrimExpr arg : GetStmtArgs(stmt)) { + if (auto buffer_load = arg.as()) { + Buffer buffer{buffer_load->buffer}; + // Check if the buffer has already been added + if (std::find(buffers.begin(), buffers.end(), buffer) == buffers.end()) { + // Check if the buffer is copied + if (_info.old_to_new_write_buffer.count(buffer.as())) { + buffers.push_back(buffer); + } + } + } + } + return buffers; + } +}; + +/*! + * \brief This mutator looks for the constants used by each compute operator + * and merges them into a single buffer. + * Constants written to a buffer with local scope are not merged. + */ +class MergeConstantsMutator : public StmtExprMutator { + public: + explicit MergeConstantsMutator(MergeConstantsInfoExtractor::Info info) : _info{std::move(info)} {} + + PrimFunc operator()(PrimFunc main_func, const Map& const_dict) { + // Rewrite + Stmt new_body = RewritePrimFuncBody(main_func->body); + std::unordered_set params_to_delete{}; + Map new_buffer_map{MakeNewBufferMap(main_func->buffer_map, ¶ms_to_delete)}; + Array new_params{MakeNewParams(main_func->params, params_to_delete)}; + + // Make the new const dict + Array> args_to_merge{GetArgsToMerge(main_func->buffer_map, main_func->params)}; + Array> buffers_to_merge{ + GetArgsToMergeWithoutArgsNotInConstDict(args_to_merge, const_dict)}; + Map new_const_dict{MakeNewConstDict(buffers_to_merge, const_dict)}; + + // Make the new prim func + auto prim_func_node{main_func.CopyOnWrite()}; + prim_func_node->body = std::move(new_body); + prim_func_node->buffer_map = std::move(new_buffer_map); + prim_func_node->params = std::move(new_params); + prim_func_node->preflattened_buffer_map = {}; + PrimFunc f{GetRef(prim_func_node)}; + + // Add the new const dict as an attribute + f = WithAttr(std::move(f), "ethos-u.const_dict", new_const_dict); + + return f; + } + + private: + /*! The information collected by the MergeConstantsInfoExtractor */ + MergeConstantsInfoExtractor::Info _info; + + /*! Maps an index representing a new buffer to the new buffer */ + std::unordered_map new_buffers{}; + + /*! Maps a copy's read buffer to the new copy's read buffer */ + std::unordered_map old_to_new_read_buffers{}; + + /*! Maps an index representing a new buffer to the list of buffers to be merged in the new buffer + */ + std::unordered_map> buffers_to_merge{}; + + /*! A set of buffers to delete */ + std::unordered_set buffers_to_delete{}; + + Stmt RewritePrimFuncBody(Stmt body) { + std::unordered_map var_to_allocate{}; + + // Rewrite old allocates + std::unordered_set buffer_vars{GetVarsForWrittenCopyBuffers()}; + for (auto it{_info.allocates.rbegin()}; it != _info.allocates.rend(); ++it) { + Allocate alloc{*it}; + var_to_allocate[alloc->buffer_var.get()] = alloc; + if (buffer_vars.count(alloc->buffer_var.as()) == 0) { + body = Allocate(alloc->buffer_var, alloc->dtype, alloc->extents, alloc->condition, body, + alloc->annotations, alloc->span); + } + } + + // Rewrite new allocates + for (auto it{_info.copy_write_buffers.rbegin()}; it != _info.copy_write_buffers.rend(); ++it) { + if (Optional buffer_opt = *it) { + Buffer old_write_buffer{buffer_opt.value()}; + int new_buffer_index{ + _info.old_to_new_write_buffer[old_write_buffer.as()].first}; + + // Check if the allocate has already been created + if (new_buffers.count(new_buffer_index) == 0) { + BufferNode* new_buffer{old_write_buffer.CopyOnWrite()}; + new_buffer->shape = {_info.new_buffers_length[new_buffer_index]}; + + new_buffers[new_buffer_index] = GetRef(new_buffer); + + Allocate old_allocate{var_to_allocate[old_write_buffer->data.get()]}; + body = Allocate(new_buffer->data, new_buffer->dtype, new_buffer->shape, tir::const_true(), + body, old_allocate->annotations, old_allocate->span); + } + } + } + + // Rewrite operators + return this->VisitStmt(body); + } + + Stmt VisitStmt_(const AllocateNode* op) override { + auto allocate{CopyOnWrite(op)}; + allocate->body = this->VisitStmt(op->body); + return Stmt(allocate); + } + + Stmt VisitStmt_(const SeqStmtNode* op) override { + if (op->size() <= 1) { + return StmtExprMutator::VisitStmt_(op); + } + + Array new_seq{}; + SeqStmt seq_stmt{GetRef(op)}; + for (size_t i{0}; i < seq_stmt.size(); ++i) { + Stmt stmt{seq_stmt[i]}; + + switch (GetStmtType(stmt)) { + case StmtType::global_copy: { + Buffer old_write_buffer{_info.copy_write_buffers[i].value()}; + std::pair pair{ + _info.old_to_new_write_buffer[old_write_buffer.as()]}; + int new_buffer_index{pair.first}; + int new_buffer_offset{pair.second}; + UpdateBuffersToMergeAndDelete(stmt, new_buffer_index, new_buffer_offset); + + if (!IsCopyToBeDeleted(new_buffer_offset)) { + Optional cycless{GetMergedCycles(new_buffer_index)}; + new_seq.push_back(MakeNewStmt( + stmt, MakeNewCopyArgs(stmt, old_write_buffer, new_buffer_index), cycless)); + } + break; + } + case StmtType::local_copy: { + new_seq.push_back(stmt); + break; + } + case StmtType::compute: { + new_seq.push_back(MakeNewStmt(stmt, MakeNewComputeArgs(stmt))); + break; + } + } + } + return SeqStmt(new_seq, op->span); + } + + /*! Returns the variables of the buffers written by copies */ + std::unordered_set GetVarsForWrittenCopyBuffers() { + std::unordered_set buffer_vars{}; + std::transform(_info.old_to_new_write_buffer.begin(), _info.old_to_new_write_buffer.end(), + std::inserter(buffer_vars, buffer_vars.begin()), + [](std::pair> pair) -> const VarNode* { + return pair.first->data.as(); + }); + return buffer_vars; + } + + /*! Returns the cycles of the new buffer at the given index */ + Optional GetMergedCycles(int new_buffer_index) { + auto it = _info.cycless.find(new_buffer_index); + if (it != _info.cycless.end()) { + return Integer(it->second); + } + return Optional{}; + } + + /*! Returns true if a copy must be deleted, false otherwise */ + bool IsCopyToBeDeleted(int new_buffer_offset) { return new_buffer_offset > 0; } + + Array MakeNewCopyArgs(const Stmt& stmt, const Buffer& old_write_buffer, + int new_buffer_index) { + Array args{GetStmtArgs(stmt)}; + int new_length{_info.new_buffers_length[new_buffer_index]}; + + Array new_args{}; + for (size_t i = 0; i < args.size(); ++i) { + switch (i) { + case 1: /* read_address */ { + auto buffer_load = args[1].as(); + Buffer buffer{buffer_load->buffer}; + Buffer new_buffer{buffer->data, + buffer->dtype, + {new_length}, + buffer->strides, + buffer->elem_offset, + buffer->name, + buffer->data_alignment, + buffer->offset_factor, + buffer->buffer_type, + buffer->axis_separators, + buffer->span}; + old_to_new_read_buffers[buffer.as()] = new_buffer; + new_args.push_back(BufferLoad(new_buffer, buffer_load->indices, buffer_load->span)); + break; + } + case 2: /* length */ { + new_args.push_back(new_length); + break; + } + case 3: /* write_address */ { + new_args.push_back(MakeNewBufferLoad(old_write_buffer, 0, true).value()); + break; + } + default: + new_args.push_back(args[i]); + break; + } + } + return new_args; + } + + Array MakeNewComputeArgs(const Stmt& stmt) { + Array args{GetStmtArgs(stmt)}; + Array new_args{}; + for (size_t i = 0; i < args.size(); ++i) { + if (auto buffer_load = args[i].as()) { + BufferLoad new_buffer_load{ + MakeNewBufferLoad(buffer_load->buffer, buffer_load->indices[0], false) + .value_or(GetRef(buffer_load))}; + new_args.push_back(new_buffer_load); + } else { + new_args.push_back(args[i]); + } + } + return new_args; + } + + Stmt MakeNewStmt(const Stmt& stmt, const Array& new_args, + Optional cycless = Optional{}) { + auto attr{stmt.as()}; + Stmt eval_stmt{attr ? attr->body : stmt}; + auto eval{eval_stmt.as()}; + ICHECK(eval) << "Expected statement to be an evaluate node, but was " + << eval_stmt->GetTypeKey(); + auto call{eval->value.as()}; + ICHECK(call) << "Expected expression to be a call node, but was " << eval->value->GetTypeKey(); + + Call new_call{call->dtype, call->op, new_args, call->span}; + Evaluate new_eval{new_call, eval->span}; + + if (attr) { + ICHECK(attr->attr_key == "pragma_compute_cycles_hint"); + PrimExpr value = cycless.value_or(attr->value); + return AttrStmt{attr->node, attr->attr_key, value, new_eval, attr->span}; + } else { + return std::move(new_eval); + } + } + + Optional MakeNewBufferLoad(const Buffer& write_buffer, const PrimExpr& old_index, + bool only_old_index) { + auto it = _info.old_to_new_write_buffer.find(write_buffer.as()); + if (it != _info.old_to_new_write_buffer.end()) { + std::pair pair{it->second}; + int new_buffer_index{pair.first}; + PrimExpr new_index{only_old_index ? old_index : (pair.second + old_index)}; + return BufferLoad{new_buffers[new_buffer_index], {new_index}}; + } + return Optional{}; + } + + Map MakeNewBufferMap(const Map& buffer_map, + std::unordered_set* params_to_delete) { + Map new_buffer_map{}; + for (std::pair pair : buffer_map) { + Var var{pair.first}; + Buffer buffer{pair.second}; + + if (buffers_to_delete.count(buffer.as()) == 1) { + params_to_delete->insert(var.as()); + } else if (old_to_new_read_buffers.count(buffer.as()) == 1) { + new_buffer_map.Set(var, old_to_new_read_buffers[buffer.as()]); + } else { + new_buffer_map.Set(var, buffer); + } + } + return new_buffer_map; + } + + Array MakeNewParams(const Array& params, + const std::unordered_set& params_to_delete) { + std::vector new_params{}; + for (Var var : params) { + if (params_to_delete.count(var.as()) == 0) { + new_params.push_back(var); + } + } + return new_params; + } + + void UpdateBuffersToMergeAndDelete(const Stmt& stmt, int new_buffer_index, + int new_buffer_offset) { + Array args{GetStmtArgs(stmt)}; + Buffer read_buffer{GetCopyReadBuffer(stmt)}; + + if (buffers_to_merge.count(new_buffer_index) == 0) { + buffers_to_merge[new_buffer_index] = std::vector{read_buffer}; + } else { + buffers_to_merge[new_buffer_index].push_back(read_buffer); + } + + if (new_buffer_offset > 0) { + buffers_to_delete.insert(read_buffer.as()); + } + } + + /*! Returns an array whose elements are the indices of the function arguments to be merged. + * Example: if a function has three arguments and the second and the third ones must + * be merged then the array is: [[0], [1, 2], [3]] */ + Array> GetArgsToMerge(const Map& buffer_map, + const Array& params) { + std::unordered_map buffer_to_var{}; + for (std::pair var_buffer : buffer_map) { + buffer_to_var[var_buffer.second.as()] = var_buffer.first; + } + + std::unordered_map var_to_index{}; + for (int i = 0; i < static_cast(params.size()); ++i) { + var_to_index[params[i].as()] = i; + } + + std::vector> vector{}; + for (std::pair> index_vector : buffers_to_merge) { + std::vector indices{}; + for (Buffer buffer : index_vector.second) { + const VarNode* var{buffer_to_var[buffer.as()].as()}; + IntImm index{DataType::Int(64), var_to_index[var]}; + var_to_index.erase(var); + auto it = std::find_if(indices.begin(), indices.end(), + [&](IntImm value) { return value->value == index->value; }); + if (it == indices.end()) { + indices.push_back(index); + } + } + vector.push_back(Array{indices}); + } + + for (std::pair var_index : var_to_index) { + vector.push_back(Array{IntImm(DataType::Int(64), var_index.second)}); + } + std::sort(vector.begin(), vector.end(), + [](Array a, Array b) { return a[0]->value < b[0]->value; }); + return vector; + } + + Array> GetArgsToMergeWithoutArgsNotInConstDict( + const Array>& args_to_merge, const Map& const_dict) { + Array> new_args_to_merge{}; + for (Array args : args_to_merge) { + IntImm key{args[0]}; + auto it = std::find_if(const_dict.begin(), const_dict.end(), + [&](std::pair pair) { + return pair.first->value == key->value; + }); + if (it != const_dict.end()) { + new_args_to_merge.push_back(args); + } + } + return new_args_to_merge; + } + + Map MakeNewConstDict(const Array>& args_to_merge, + Map const_dict) { + Map new_const_dict{}; + if (args_to_merge.size() == 0) { + return new_const_dict; + } + + int64_t key = args_to_merge[0][0]->value; + for (Array args : args_to_merge) { + int64_t size = 0; + for (IntImm arg : args) { + auto it = std::find_if(const_dict.begin(), const_dict.end(), + [&](auto pair) { return pair.first->value == arg->value; }); + runtime::NDArray arg_constant{(*it).second}; + size += runtime::GetDataSize(*arg_constant.operator->()); + } + + runtime::NDArray constant = runtime::NDArray::Empty({size}, DataType::UInt(8), {kDLCPU, 0}); + + size_t offset = 0; + for (IntImm arg : args) { + auto it = std::find_if(const_dict.begin(), const_dict.end(), + [&](auto pair) { return pair.first->value == arg->value; }); + runtime::NDArray arg_constant{(*it).second}; + size_t nbytes = runtime::GetDataSize(*arg_constant.operator->()); + arg_constant.CopyToBytes(static_cast(constant->data) + offset, nbytes); + offset += nbytes; + } + new_const_dict.Set(IntImm(DataType::Int(64), key), constant); + key += 1; + } + return new_const_dict; + } +}; + +/*! + * \brief This pass looks for the constants used by each compute operator + * and merges them into a single buffer. + * Constants written to a buffer with local scope are not merged. + * \return tvm::transform::Pass + */ +tvm::transform::Pass MergeConstants() { + auto pass_func = [=](PrimFunc f, IRModule mod, tvm::transform::PassContext ctx) { + ICHECK(mod->GetGlobalVars().size() == 1 && mod->ContainGlobalVar("main")) + << "Expected a single primitive function called 'main'. Please run the " + "MergeConstants pass in conjunction with the LowerToTIR() pass."; + Optional> const_dict{ + f->attrs.GetAttr("ethos-u.const_dict", Optional>{})}; + ICHECK(const_dict) << "Expected a ethos-u.const_dict attribute"; + + MergeConstantsInfoExtractor::Info info{MergeConstantsInfoExtractor()(f)}; + f = RemoveAllocatesMutator()(f); + return MergeConstantsMutator(info)(f, const_dict.value()); + }; + return tvm::tir::transform::CreatePrimFuncPass(pass_func, 0, "tir.contrib.ethos-u.MergeConstants", + {}); +} + +TVM_REGISTER_GLOBAL("tir.contrib.ethos-u.MergeConstants").set_body_typed(MergeConstants); + +/*! + * \brief This pass removes the ethos-u.const_dict attribute + * \return tvm::transform::Pass + */ +class RemoveConstDictAttributeMutator : public StmtExprMutator { + public: + RemoveConstDictAttributeMutator() {} + + PrimFunc operator()(PrimFunc main_func) { + return WithoutAttr(std::move(main_func), "ethos-u.const_dict"); + } +}; + +tvm::transform::Pass RemoveConstDictAttribute() { + auto pass_func = [=](PrimFunc f, IRModule mod, tvm::transform::PassContext ctx) { + return RemoveConstDictAttributeMutator()(f); + }; + return tvm::tir::transform::CreatePrimFuncPass( + pass_func, 0, "tir.contrib.ethos-u.RemoveConstDictAttribute", {}); +} + +TVM_REGISTER_GLOBAL("tir.contrib.ethos-u.RemoveConstDictAttribute") + .set_body_typed(RemoveConstDictAttribute); + } // namespace ethosu } // namespace contrib } // namespace tir diff --git a/tests/python/contrib/test_ethosu/cascader/test_integration.py b/tests/python/contrib/test_ethosu/cascader/test_integration.py index 8e1f020861d5..14cc8fbc61cf 100644 --- a/tests/python/contrib/test_ethosu/cascader/test_integration.py +++ b/tests/python/contrib/test_ethosu/cascader/test_integration.py @@ -109,9 +109,8 @@ def test_single_conv_compute_cycles_hint(): for single convolution. """ primfunc = _compile_model(_create_single_conv2d()) - ops = primfunc.body.body.body.seq - - compute_cycles_hints = [2304, 640, 320] + ops = primfunc.body.body.seq + compute_cycles_hints = [2944, 320] for op, compute_cycle_hint in zip(ops, compute_cycles_hints): assert op.attr_key == "pragma_compute_cycles_hint" assert op.value == compute_cycle_hint @@ -123,9 +122,8 @@ def test_double_conv_compute_cycles_hint(): for double convolution. """ primfunc = _compile_model(_create_double_conv2d()) - ops = primfunc.body.body.body.body.body.body.seq - - compute_cycles_hints = [2304, 640, 768, 640, 320, 240] + ops = primfunc.body.body.body.body.seq + compute_cycles_hints = [2944, 1408, 320, 240] for op, compute_cycle_hint in zip(ops, compute_cycles_hints): assert op.attr_key == "pragma_compute_cycles_hint" assert op.value == compute_cycle_hint diff --git a/tests/python/contrib/test_ethosu/test_encode_constants.py b/tests/python/contrib/test_ethosu/test_encode_constants.py index 15b719f33c3f..fd9f373739e1 100644 --- a/tests/python/contrib/test_ethosu/test_encode_constants.py +++ b/tests/python/contrib/test_ethosu/test_encode_constants.py @@ -37,34 +37,23 @@ class WeightStreamOnlyU55: def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - buffer1 = T.buffer_decl([128], "uint8") - buffer2 = T.buffer_decl([32], "uint8") - buffer3 = T.buffer_decl([112], "uint8") - buffer4 = T.buffer_decl([32], "uint8") - buffer5 = T.buffer_decl([112], "uint8") - buffer6 = T.buffer_decl([32], "uint8") - buffer7 = T.buffer_decl([112], "uint8") + buffer1 = T.buffer_decl([160], "uint8") + buffer3 = T.buffer_decl([144], "uint8") + buffer5 = T.buffer_decl([144], "uint8") + buffer7 = T.buffer_decl([144], "uint8") buffer8 = T.buffer_decl([32], "uint8") - T.preflattened_buffer(placeholder, [1, 16, 16, 32], "int8", data=placeholder.data) - T.preflattened_buffer(ethosu_write, [1, 16, 16, 8], "int8", data=ethosu_write.data) # body - p1 = T.allocate([128], "uint8", "global", annotations={"disable_lower_builtin":True}) - p2 = T.allocate([32], "uint8", "global", annotations={"disable_lower_builtin":True}) - p3 = T.allocate([112], "uint8", "global", annotations={"disable_lower_builtin":True}) - p4 = T.allocate([32], "uint8", "global", annotations={"disable_lower_builtin":True}) - buffer9 = T.buffer_decl([112], "uint8", data=p1.data) - T.evaluate(T.call_extern("ethosu_copy", buffer1[0], 128, p1[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 32, p2[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", buffer3[0], 112, p3[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", buffer4[0], 32, p4[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 128, T.int8(-1), T.int8(-1), 12, p2[0], 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", buffer5[0], 112, buffer9[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", buffer6[0], 32, p2[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[2], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p3[0], 112, T.int8(-1), T.int8(-1), 12, p4[0], 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", buffer7[0], 112, p3[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", buffer8[0], 32, p4[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[4], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, buffer9[0], 112, T.int8(-1), T.int8(-1), 12, p2[0], 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[6], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p3[0], 112, T.int8(-1), T.int8(-1), 12, p4[0], 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + p1 = T.allocate([160], "uint8", "global", annotations={"disable_lower_builtin":True}) + p2 = T.allocate([144], "uint8", "global", annotations={"disable_lower_builtin":True}) + buffer9 = T.buffer_decl([144], "uint8", data=p1.data) + T.evaluate(T.call_extern("ethosu_copy", buffer1[0], 160, p1[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer3[0], 144, p2[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 128, T.int8(-1), T.int8(-1), 12, p1[128], 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer5[0], 144, buffer9[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[2], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p2[0], 112, T.int8(-1), T.int8(-1), 12, p2[112], 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer7[0], 144, p2[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[4], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, buffer9[0], 112, T.int8(-1), T.int8(-1), 12, buffer9[112], 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[6], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p2[0], 112, T.int8(-1), T.int8(-1), 12, p2[112], 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) __tvm_meta__ = None @@ -75,34 +64,22 @@ def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) # buffer definition - T.preflattened_buffer(placeholder, [1, 16, 16, 32], dtype="int8", data=placeholder.data) - T.preflattened_buffer(ethosu_write, [1, 16, 16, 8], dtype="int8", data=ethosu_write.data) - buffer_encoded_1 = T.buffer_decl([160], dtype="uint8") - buffer_encoded_1_1 = T.buffer_decl([32], dtype="uint8") - buffer_encoded_2_1 = T.buffer_decl([160], dtype="uint8") - buffer_encoded_3_1 = T.buffer_decl([32], dtype="uint8") - buffer_encoded_4_1 = T.buffer_decl([176], dtype="uint8") - buffer_encoded_5_1 = T.buffer_decl([32], dtype="uint8") - buffer_encoded_6_1 = T.buffer_decl([160], dtype="uint8") - buffer_encoded_7_1 = T.buffer_decl([32], dtype="uint8") + buffer_encoded_1 = T.buffer_decl([192], dtype="uint8") + buffer_encoded_2_1 = T.buffer_decl([192], dtype="uint8") + buffer_encoded_4_1 = T.buffer_decl([208], dtype="uint8") + buffer_encoded_6_1 = T.buffer_decl([192], dtype="uint8") # body - placeholder_global = T.allocate([176], "uint8", "global", annotations={"disable_lower_builtin":True}) - placeholder_d_global = T.allocate([32], "uint8", "global", annotations={"disable_lower_builtin":True}) - placeholder_global_2 = T.allocate([160], "uint8", "global", annotations={"disable_lower_builtin":True}) - placeholder_d_global_2 = T.allocate([32], "uint8", "global", annotations={"disable_lower_builtin":True}) - placeholder_global_1 = T.buffer_decl([160], dtype="uint8", data=placeholder_global.data) - T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_1[0], 160, placeholder_global_1[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_1_1[0], 32, placeholder_d_global[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_2_1[0], 160, placeholder_global_2[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_3_1[0], 32, placeholder_d_global_2[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global_1[0], 80, placeholder_global_1[80], 80, 12, placeholder_d_global[0], 16, placeholder_d_global[16], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_4_1[0], 176, placeholder_global[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_5_1[0], 32, placeholder_d_global[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[2], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global_2[0], 80, placeholder_global_2[80], 80, 12, placeholder_d_global_2[0], 16, placeholder_d_global_2[16], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_6_1[0], 160, placeholder_global_2[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_7_1[0], 32, placeholder_d_global_2[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[4], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 96, placeholder_global[96], 80, 12, placeholder_d_global[0], 16, placeholder_d_global[16], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[6], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global_2[0], 80, placeholder_global_2[80], 80, 12, placeholder_d_global_2[0], 16, placeholder_d_global_2[16], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + p1 = T.allocate([208], "uint8", "global", annotations={"disable_lower_builtin":True}) + p2 = T.allocate([192], "uint8", "global", annotations={"disable_lower_builtin":True}) + p3 = T.buffer_decl([192], dtype="uint8", data=p1.data) + T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_1[0], 192, p3[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_2_1[0], 192, p2[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p3[0], 80, p3[80], 80, 12, p3[160], 16, p3[176], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_4_1[0], 208, p1[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[2], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p2[0], 80, p2[80], 80, 12, p2[160], 16, p2[176], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_6_1[0], 192, p2[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[4], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 96, p1[96], 80, 12, p1[176], 16, p1[192], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[6], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p2[0], 80, p2[80], 80, 12, p2[160], 16, p2[176], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) __tvm_meta__ = None # fmt: on @@ -113,12 +90,12 @@ def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), ( "ethos-u55-128", WeightStreamOnlyU55, - [128, 32, 112, 32, 112, 32, 112, 32], + [160, 144, 144, 144], ), ( "ethos-u65-512", WeightStreamOnlyU65, - [160, 32, 160, 32, 176, 32, 160, 32], + [192, 192, 208, 192], ), ], ) @@ -160,7 +137,7 @@ def _get_func(): tvm.ir.assert_structural_equal(test_mod["main"], reference_mod["main"], True) test_const_size = [value.size for value in list(consts.values())] - assert reference_const_sizes == test_const_size + assert reference_const_sizes.sort() == test_const_size.sort() # fmt: off @@ -170,21 +147,14 @@ class RereadWeightsU55: def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - buffer1 = T.buffer_decl([304], "uint8") - buffer2 = T.buffer_decl([80], "uint8") - T.preflattened_buffer(placeholder, [1, 16, 16, 32], "int8", data=placeholder.data) - T.preflattened_buffer(ethosu_write, [1, 16, 16, 8], "int8", data=ethosu_write.data) + buffer1 = T.buffer_decl([384], "uint8") # body - p1 = T.allocate([304], "uint8", "global", annotations={"disable_lower_builtin":True}) - p2 = T.allocate([80], "uint8", "global", annotations={"disable_lower_builtin":True}) - p3 = T.allocate([304], "uint8", "global", annotations={"disable_lower_builtin":True}) - p4 = T.allocate([80], "uint8", "global", annotations={"disable_lower_builtin":True}) - T.evaluate(T.call_extern("ethosu_copy", buffer1[0], 304, p1[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 80, p2[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", buffer1[0], 304, p3[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 80, p4[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 304, T.int8(-1), T.int8(-1), 12, p2[0], 80, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[256], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[64], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p3[0], 304, T.int8(-1), T.int8(-1), 12, p4[0], 80, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + p1 = T.allocate([384], "uint8", "global", annotations={"disable_lower_builtin":True}) + p2 = T.allocate([384], "uint8", "global", annotations={"disable_lower_builtin":True}) + T.evaluate(T.call_extern("ethosu_copy", buffer1[0], 384, p1[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer1[0], 384, p2[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 304, T.int8(-1), T.int8(-1), 12, p1[304], 80, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[256], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[64], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p2[0], 304, T.int8(-1), T.int8(-1), 12, p2[304], 80, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) __tvm_meta__ = None @@ -195,21 +165,14 @@ def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) # buffer definition - T.preflattened_buffer(placeholder, [1, 16, 16, 32], dtype="int8", data=placeholder.data) - T.preflattened_buffer(ethosu_write, [1, 16, 16, 8], dtype="int8", data=ethosu_write.data) - placeholder_encoded_1 = T.buffer_decl([368], "uint8") - placeholder_encoded_1_2 = T.buffer_decl([96], "uint8") + placeholder_encoded_1 = T.buffer_decl([464], "uint8") # body - placeholder_global = T.allocate([368], "uint8", "global", annotations={"disable_lower_builtin":True}) - placeholder_d_global = T.allocate([96], "uint8", "global", annotations={"disable_lower_builtin":True}) - placeholder_global_1 = T.allocate([368], "uint8", "global", annotations={"disable_lower_builtin":True}) - placeholder_d_global_1 = T.allocate([96], "uint8", "global", annotations={"disable_lower_builtin":True}) - T.evaluate(T.call_extern("ethosu_copy", placeholder_encoded_1[0], 368, placeholder_global[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", placeholder_encoded_1_2[0], 96, placeholder_d_global[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", placeholder_encoded_1[0], 368, placeholder_global_1[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", placeholder_encoded_1_2[0], 96, placeholder_d_global_1[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 192, placeholder_global[192], 176, 12, placeholder_d_global[0], 48, placeholder_d_global[48], 48, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[256], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[64], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global_1[0], 192, placeholder_global_1[192], 176, 12, placeholder_d_global_1[0], 48, placeholder_d_global_1[48], 48, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + p1 = T.allocate([464], "uint8", "global", annotations={"disable_lower_builtin":True}) + p2 = T.allocate([464], "uint8", "global", annotations={"disable_lower_builtin":True}) + T.evaluate(T.call_extern("ethosu_copy", placeholder_encoded_1[0], 464, p1[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", placeholder_encoded_1[0], 464, p2[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 192, p1[192], 176, 12, p1[368], 48, p1[416], 48, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[256], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[64], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p2[0], 192, p2[192], 176, 12, p2[368], 48, p2[416], 48, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) __tvm_meta__ = None # fmt: on @@ -221,12 +184,12 @@ def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), ( "ethos-u55-128", RereadWeightsU55, - [304, 80], + [384], ), ( "ethos-u65-512", RereadWeightsU65, - [368, 96], + [464], ), ], ) @@ -268,7 +231,7 @@ def _get_func(): tvm.ir.assert_structural_equal(test_mod["main"], reference_mod["main"], True) test_const_size = [value.size for value in list(consts.values())] - assert reference_const_sizes == test_const_size + assert reference_const_sizes.sort() == test_const_size.sort() # fmt: off @@ -282,8 +245,6 @@ def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), buffer_1 = T.buffer_decl([160], "uint8") buffer_2 = T.buffer_decl([160], "uint8") buffer_3 = T.buffer_decl([80], "uint8") - T.preflattened_buffer(placeholder, [1, 16, 16, 32], "int8", data=placeholder.data) - T.preflattened_buffer(ethosu_write, [1, 16, 16, 8], "int8", data=ethosu_write.data) # body ethosu_write_1 = T.allocate([4096], "int8", "global", annotations={"disable_lower_builtin":True}) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 16, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, buffer[0], 592, T.int8(-1), T.int8(-1), 12, buffer_1[0], 160, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) @@ -302,8 +263,6 @@ def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), placeholder_encoded_1 = T.buffer_decl([160], dtype="uint8") placeholder_encoded_2 = T.buffer_decl([208], dtype="uint8") placeholder_encoded_3 = T.buffer_decl([96], dtype="uint8") - T.preflattened_buffer(placeholder, [1, 16, 16, 32], dtype="int8", data=placeholder.data) - T.preflattened_buffer(ethosu_write, [1, 16, 16, 8], dtype="int8", data=ethosu_write.data) # body ethosu_write_2 = T.allocate([4096], "int8", "global", annotations={"disable_lower_builtin":True}) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 16, 16, 0, 16, ethosu_write_2[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, placeholder_encoded[0], 304, placeholder_encoded[304], 304, 12, placeholder_encoded_1[0], 80, placeholder_encoded_1[80], 80, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) @@ -364,87 +323,64 @@ def _get_func(): tvm.ir.assert_structural_equal(test_mod["main"], reference_mod["main"], True) test_const_size = [value.size for value in list(consts.values())] - assert reference_const_sizes == test_const_size + assert reference_const_sizes.sort() == test_const_size.sort() # fmt: off @tvm.script.ir_module class MixedReadU55: @T.prim_func - def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), "int8"]) -> None: + def main(placeholder: T.Buffer[(8192,), "int8"], buffer_encoded: T.Buffer[(112,), "uint8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - buffer1 = T.buffer_decl([80], "uint8") - buffer2 = T.buffer_decl([32], "uint8") - buffer3 = T.buffer_decl([80], "uint8") - buffer4 = T.buffer_decl([32], "uint8") - buffer5 = T.buffer_decl([80], "uint8") - buffer6 = T.buffer_decl([32], "uint8") - buffer7 = T.buffer_decl([80], "uint8") - buffer8 = T.buffer_decl([32], "uint8") + buffer1 = T.buffer_decl([112], "uint8") + buffer3 = T.buffer_decl([112], "uint8") + buffer5 = T.buffer_decl([112], "uint8") buffer9 = T.buffer_decl([592], "uint8") buffer10 = T.buffer_decl([160], "uint8") - T.preflattened_buffer(placeholder, [1, 16, 16, 32], "int8", data=placeholder.data) - T.preflattened_buffer(ethosu_write, [1, 16, 16, 8], "int8", data=ethosu_write.data) + buffer11 = T.buffer_decl([2048], "int8") # body - p1 = T.allocate([80], "uint8", "global", annotations={"disable_lower_builtin":True}) - p2 = T.allocate([32], "uint8", "global", annotations={"disable_lower_builtin":True}) + p1 = T.allocate([112], "uint8", "global", annotations={"disable_lower_builtin":True}) p3 = T.allocate([4096], "int8", "global", annotations={"disable_lower_builtin":True}) - p4 = T.allocate([80], "uint8", "global", annotations={"disable_lower_builtin":True}) - p5 = T.allocate([32], "uint8", "global", annotations={"disable_lower_builtin":True}) - T.evaluate(T.call_extern("ethosu_copy", buffer1[0], 80, p1[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 32, p2[0], dtype="handle")) + p2 = T.allocate([112], "uint8", "global", annotations={"disable_lower_builtin":True}) + T.evaluate(T.call_extern("ethosu_copy", buffer1[0], 112, p1[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 16, 16, 0, 16, p3[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, buffer9[0], 592, T.int8(-1), T.int8(-1), 12, buffer10[0], 160, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", buffer3[0], 80, p4[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", buffer4[0], 32, p5[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, p3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 80, T.int8(-1), T.int8(-1), 12, p2[0], 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", buffer5[0], 80, p1[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", buffer6[0], 32, p2[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, p3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[2], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p4[0], 80, T.int8(-1), T.int8(-1), 12, p5[0], 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", buffer7[0], 80, p4[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", buffer8[0], 32, p5[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, p3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[4], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 80, T.int8(-1), T.int8(-1), 12, p2[0], 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, p3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[6], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p4[0], 80, T.int8(-1), T.int8(-1), 12, p5[0], 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer3[0], 112, p2[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, p3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, buffer11[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 80, T.int8(-1), T.int8(-1), 12, p1[80], 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer5[0], 112, p1[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, p3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, buffer11[2], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p2[0], 80, T.int8(-1), T.int8(-1), 12, p2[80], 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_encoded[0], 112, p2[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, p3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, buffer11[4], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 80, T.int8(-1), T.int8(-1), 12, p1[80], 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, p3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, buffer11[6], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p2[0], 80, T.int8(-1), T.int8(-1), 12, p2[80], 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) __tvm_meta__ = None @tvm.script.ir_module class MixedReadU65: @T.prim_func - def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), "int8"]) -> None: + def main(placeholder: T.Buffer[(8192,), "int8"], buffer_encoded: T.Buffer[(128,), "uint8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - T.preflattened_buffer(ethosu_write, [1, 16, 16, 8], "int8", data=ethosu_write.data) - T.preflattened_buffer(placeholder, [1, 16, 16, 32], "int8", data=placeholder.data) + # buffer definition - buffer_encoded_1 = T.buffer_decl([96], dtype="uint8") - buffer_encoded_1_2 = T.buffer_decl([32], dtype="uint8") - placeholder_encoded_1 = T.buffer_decl([608], dtype="uint8") - placeholder_encoded_1_2 = T.buffer_decl([160], dtype="uint8") - buffer_encoded_2_1 = T.buffer_decl([96], dtype="uint8") - buffer_encoded_3_1 = T.buffer_decl([32], dtype="uint8") - buffer_encoded_4_1 = T.buffer_decl([96], dtype="uint8") - buffer_encoded_5_1 = T.buffer_decl([32], dtype="uint8") - buffer_encoded_6_1 = T.buffer_decl([96], dtype="uint8") - buffer_encoded_7_1 = T.buffer_decl([32], dtype="uint8") - placeholder_global = T.allocate([96], "uint8", "global", annotations={"disable_lower_builtin":True}) - placeholder_d_global = T.allocate([32], "uint8", "global", annotations={"disable_lower_builtin":True}) - ethosu_write_2 = T.allocate([4096], "int8", "global", annotations={"disable_lower_builtin":True}) - placeholder_global_2 = T.allocate([96], "uint8", "global", annotations={"disable_lower_builtin":True}) - placeholder_d_global_2 = T.allocate([32], "uint8", "global", annotations={"disable_lower_builtin":True}) - T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_1[0], 96, placeholder_global[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_1_2[0], 32, placeholder_d_global[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 16, 16, 0, 16, ethosu_write_2[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, placeholder_encoded_1[0], 304, placeholder_encoded_1[304], 304, 12, placeholder_encoded_1_2[0], 80, placeholder_encoded_1_2[80], 80, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_2_1[0], 96, placeholder_global_2[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_3_1[0], 32, placeholder_d_global_2[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, ethosu_write_2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 48, placeholder_global[48], 48, 12, placeholder_d_global[0], 16, placeholder_d_global[16], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_4_1[0], 96, placeholder_global[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_5_1[0], 32, placeholder_d_global[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, ethosu_write_2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[2], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global_2[0], 48, placeholder_global_2[48], 48, 12, placeholder_d_global_2[0], 16, placeholder_d_global_2[16], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_6_1[0], 96, placeholder_global_2[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_7_1[0], 32, placeholder_d_global_2[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, ethosu_write_2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[4], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 48, placeholder_global[48], 48, 12, placeholder_d_global[0], 16, placeholder_d_global[16], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, ethosu_write_2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[6], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global_2[0], 48, placeholder_global_2[48], 48, 12, placeholder_d_global_2[0], 16, placeholder_d_global_2[16], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + buffer1 = T.buffer_decl([128], dtype="uint8") + buffer2 = T.buffer_decl([128], dtype="uint8") + buffer3 = T.buffer_decl([128], dtype="uint8") + buffer4 = T.buffer_decl([608], dtype="uint8") + buffer5 = T.buffer_decl([160], dtype="uint8") + buffer6 = T.buffer_decl([2048], dtype="int8") + p1 = T.allocate([128], "uint8", "global", annotations={"disable_lower_builtin":True}) + p2 = T.allocate([4096], "int8", "global", annotations={"disable_lower_builtin":True}) + p3 = T.allocate([128], "uint8", "global", annotations={"disable_lower_builtin":True}) + T.evaluate(T.call_extern("ethosu_copy", buffer1[0], 128, p1[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 16, 16, 0, 16, p2[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, buffer4[0], 304, buffer4[304], 304, 12, buffer5[0], 80, buffer5[80], 80, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 128, p3[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, p2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, buffer6[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 48, p1[48], 48, 12, p1[96], 16, p1[112], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer3[0], 128, p1[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, p2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, buffer6[2], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p3[0], 48, p3[48], 48, 12, p3[96], 16, p3[112], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_encoded[0], 128, p3[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, p2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, buffer6[4], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 48, p1[48], 48, 12, p1[96], 16, p1[112], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, p2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, buffer6[6], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p3[0], 48, p3[48], 48, 12, p3[96], 16, p3[112], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) __tvm_meta__ = None # fmt: on @@ -455,12 +391,12 @@ def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), ( "ethos-u55-128", MixedReadU55, - [592, 160, 80, 32, 80, 32, 80, 32, 80, 32], + [592, 160, 112, 112, 112, 112], ), ( "ethos-u65-512", MixedReadU65, - [608, 160, 96, 32, 96, 32, 96, 32, 96, 32], + [608, 160, 128, 128, 128, 128], ), ], ) @@ -512,7 +448,7 @@ def _get_func(): tvm.ir.assert_structural_equal(test_mod["main"], reference_mod["main"], True) test_const_size = [value.size for value in list(consts.values())] - assert reference_const_sizes == test_const_size + assert reference_const_sizes.sort() == test_const_size.sort() def test_constant_as_input(): @@ -543,7 +479,7 @@ def get_graph(): # Check tile address for the scalar constant input hasn't been # overwritten. - extern_calls = tir_mod["main"].body.body.body.body.body + extern_calls = tir_mod["main"].body.body.body.body binary_elementwise = extern_calls[-1].value args = binary_elementwise.args diff --git a/tests/python/contrib/test_ethosu/test_merge_constants.py b/tests/python/contrib/test_ethosu/test_merge_constants.py new file mode 100644 index 000000000000..caf09abdb020 --- /dev/null +++ b/tests/python/contrib/test_ethosu/test_merge_constants.py @@ -0,0 +1,561 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest + +pytest.importorskip("ethosu.vela") + +import tvm +from tvm.script import tir as T +from tvm.relay.backend.contrib.ethosu.tir.passes import MergeConstants +import numpy as np + + +def check_const_dictionaries(const_dict, new_const_dict): + assert list(const_dict) == list(new_const_dict) + for key, value in const_dict.items(): + new_value = new_const_dict[key] + assert len(value) == len(new_value) + for i in range(len(value)): + assert value[i] == new_value[i] + + +def test_only_one_operator(): + # fmt: off + @tvm.script.ir_module + class InputModule: + @T.prim_func + def main(buffer2: T.Buffer[(128,), "uint8"], buffer3: T.Buffer[(32,), "uint8"]) -> None: + # function attr dict + T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + buffer1 = T.buffer_decl([8192], "int8") + buffer10 = T.buffer_decl([2048], "int8") + # body + p1 = T.allocate([128], "uint8", "global") + p4 = T.allocate([32], "uint8", "global") + T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 128, p1[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer3[0], 32, p4[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, buffer1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, buffer10[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 128, 12, p4[0], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + + + @tvm.script.ir_module + class ReferenceModule: + @T.prim_func + def main(buffer2: T.Buffer[(160,), "uint8"]) -> None: + # function attr dict + T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + buffer1 = T.buffer_decl([8192], "int8") + buffer10 = T.buffer_decl([2048], "int8") + # body + p4 = T.allocate([160], "uint8", "global") + T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 160, p4[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, buffer1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, buffer10[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p4[0], 128, 12, p4[128], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + # fmt: on + const_dict = { + 0: np.array([0, 10], dtype=np.uint8), + 1: np.array([1, 11], dtype=np.uint8), + } + new_const_dict = {0: np.concatenate((const_dict[0], const_dict[1]))} + test_mod, const_dict = MergeConstants(const_dict)(InputModule) + reference_mod = ReferenceModule + tvm.ir.assert_structural_equal(test_mod, reference_mod, True) + check_const_dictionaries(const_dict, new_const_dict) + + +def test_all_operators_with_weights(): + # fmt: off + @tvm.script.ir_module + class InputModule: + @T.prim_func + def main(buffer2: T.Buffer[(128,), "uint8"], buffer3: T.Buffer[(32,), "uint8"], buffer4: T.Buffer[(112,), "uint8"], buffer5: T.Buffer[(32,), "uint8"], buffer6: T.Buffer[(112,), "uint8"], buffer7: T.Buffer[(32,), "uint8"], buffer8: T.Buffer[(112,), "uint8"], buffer9: T.Buffer[(32,), "uint8"]) -> None: + # function attr dict + T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + buffer1 = T.buffer_decl([8192], "int8") + buffer10 = T.buffer_decl([2048], "int8") + # body + p1 = T.allocate([128], "uint8", "global") + p2 = T.allocate([112], "uint8", "global") + p3 = T.allocate([112], "uint8", "global") + p4 = T.allocate([32], "uint8", "global") + p5 = T.allocate([32], "uint8", "global") + p6 = T.allocate([32], "uint8", "global") + p7 = T.allocate([112], "uint8", "global") + p8 = T.allocate([3], "uint8", "global") + T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 128, p1[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer3[0], 32, p4[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer4[0], 112, p2[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer5[0], 32, p5[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, buffer1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, buffer10[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 128, 12, p4[0], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer6[0], 112, p3[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer7[0], 32, p6[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, buffer1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, buffer10[2], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p2[0], 112, 12, p5[0], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer8[0], 112, p7[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer9[0], 32, p8[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, buffer1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, buffer10[4], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p3[0], 112, 12, p6[0], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, buffer1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, buffer10[6], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p7[0], 112, 12, p8[0], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + + + @tvm.script.ir_module + class ReferenceModule: + @T.prim_func + def main(buffer2: T.Buffer[(160,), "uint8"], buffer4: T.Buffer[(144,), "uint8"], buffer6: T.Buffer[(144,), "uint8"], buffer8: T.Buffer[(144,), "uint8"]) -> None: + # function attr dict + T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + buffer1 = T.buffer_decl([8192], "int8") + buffer10 = T.buffer_decl([2048], "int8") + # body + p4 = T.allocate([160], "uint8", "global") + p7 = T.allocate([144], "uint8", "global") + p10 = T.allocate([144], "uint8", "global") + p11 = T.allocate([144], "uint8", "global") + T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 160, p4[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer4[0], 144, p7[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, buffer1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, buffer10[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p4[0], 128, 12, p4[128], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer6[0], 144, p10[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, buffer1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, buffer10[2], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p7[0], 112, 12, p7[112], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer8[0], 144, p11[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, buffer1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, buffer10[4], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p10[0], 112, 12, p10[112], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, buffer1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, buffer10[6], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p11[0], 112, 12, p11[112], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + # fmt: on + + const_dict = { + 0: np.array([0], dtype=np.uint8), + 1: np.array([1], dtype=np.uint8), + 2: np.array([2], dtype=np.uint8), + 3: np.array([3], dtype=np.uint8), + 4: np.array([4], dtype=np.uint8), + 5: np.array([5], dtype=np.uint8), + 6: np.array([6], dtype=np.uint8), + 7: np.array([7], dtype=np.uint8), + } + new_const_dict = { + 0: np.concatenate((const_dict[0], const_dict[1])), + 1: np.concatenate((const_dict[2], const_dict[3])), + 2: np.concatenate((const_dict[4], const_dict[5])), + 3: np.concatenate((const_dict[6], const_dict[7])), + } + test_mod, const_dict = MergeConstants(const_dict)(InputModule) + reference_mod = ReferenceModule + tvm.ir.assert_structural_equal(test_mod, reference_mod, True) + check_const_dictionaries(const_dict, new_const_dict) + + +def test_operators_with_and_without_weights(): + # fmt: off + @tvm.script.ir_module + class InputModule: + @T.prim_func + def main(buffer2: T.Buffer[(80,), "uint8"], buffer3: T.Buffer[(64,), "uint8"]) -> None: + T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + buffer0 = T.buffer_decl([390336], "int8") + buffer1 = T.buffer_decl([97156], "int8") + buffer6 = T.buffer_decl([390336], "int8") + # body + p2 = T.allocate([80], "uint8", "global") + p3 = T.allocate([64], "uint8", "global") + T.evaluate(T.call_extern("ethosu_pooling", "int8", 214, 227, 2, 214, 0, 227, buffer1[0], 0, 0, 0, T.float32(1), 0, "NHWC", 454, 2, 1, "int8", 214, 114, 2, 214, 0, 114, buffer0[0], 0, 0, 0, T.float32(1), 0, "NHCWB16", 1824, 16, 1, "MAX", 2, 1, 2, 1, 1, 1, 0, 0, 0, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 80, p2[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer3[0], 64, p3[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 214, 114, 2, 214, 0, 114, buffer0[0], 0, 0, 0, T.float32(0.00392157), -128, "NHCWB16", 1824, 16, 1, "int8", 214, 114, 5, 214, 0, 114, buffer6[0], 0, 0, 0, T.float32(0.0174839), -128, "NHCWB16", 1824, 16, 1, 3, 1, 1, 1, 1, 2, p2[0], 80, 0, p3[0], 64, 0, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + + + @tvm.script.ir_module + class ReferenceModule: + @T.prim_func + def main(buffer2: T.Buffer[(144,), "uint8"]) -> None: + T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + buffer0 = T.buffer_decl([390336], "int8") + buffer1 = T.buffer_decl([97156], "int8") + buffer6 = T.buffer_decl([390336], "int8") + # body + p3 = T.allocate([144], "uint8", "global") + T.evaluate(T.call_extern("ethosu_pooling", "int8", 214, 227, 2, 214, 0, 227, buffer1[0], 0, 0, 0, T.float32(1), 0, "NHWC", 454, 2, 1, "int8", 214, 114, 2, 214, 0, 114, buffer0[0], 0, 0, 0, T.float32(1), 0, "NHCWB16", 1824, 16, 1, "MAX", 2, 1, 2, 1, 1, 1, 0, 0, 0, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 144, p3[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 214, 114, 2, 214, 0, 114, buffer0[0], 0, 0, 0, T.float32(0.00392157), -128, "NHCWB16", 1824, 16, 1, "int8", 214, 114, 5, 214, 0, 114, buffer6[0], 0, 0, 0, T.float32(0.0174839), -128, "NHCWB16", 1824, 16, 1, 3, 1, 1, 1, 1, 2, p3[0], 80, 0, p3[80], 64, 0, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + # fmt: on + + const_dict = { + 0: np.array([0], dtype=np.uint8), + 1: np.array([1], dtype=np.uint8), + } + new_const_dict = {0: np.concatenate((const_dict[0], const_dict[1]))} + test_mod, const_dict = MergeConstants(const_dict)(InputModule) + reference_mod = ReferenceModule + tvm.ir.assert_structural_equal(test_mod, reference_mod, True) + check_const_dictionaries(const_dict, new_const_dict) + + +def test_copy_to_buffer_with_local_scope(): + # fmt: off + @tvm.script.ir_module + class InputModule: + @T.prim_func + def main(buffer1: T.Buffer[(64,), "uint8"], + buffer2: T.Buffer[(48,), "uint8"], + buffer3: T.Buffer[(256,), "uint8"], + buffer4: T.Buffer[(256,), "uint8"], + buffer5: T.Buffer[(16,), "uint8"], + buffer6: T.Buffer[(48,), "uint8"], + buffer7: T.Buffer[(256,), "uint8"], + buffer8: T.Buffer[(64,), "uint8"], + buffer9: T.Buffer[(256,), "int8"], + ) -> None: + T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + # body + p1 = T.allocate([48], "uint8", "global") + p2 = T.allocate([48], "uint8", "global") + p3 = T.allocate([256], "int8", "local") + p5 = T.allocate([16], "uint8", "global") + p6 = T.allocate([48], "uint8", "global") + p7 = T.allocate([256], "int8", "local") + T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 48, p1[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer3[0], 48, p2[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer4[0], 256, p3[0], dtype="handle")) # Local + T.evaluate(T.call_extern("ethosu_copy", buffer5[0], 16, p5[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer6[0], 48, p6[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 4, 4, 4, 4, 0, 4, buffer1[0], 0, 0, 0, T.float32(0.00392081), -128, "NHWC", 16, 4, 1, "int8", 4, 4, 4, 4, 0, 4, buffer9[0], 0, 0, 0, T.float32(0.00839574), -128, "NHCWB16", 64, 16, 1, 1, 1, 1, 1, 1, 1, p1[0], 48, 0, p2[0], 48, 0, 0, 0, 0, "TANH", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer7[0], 256, p7[0], dtype="handle")) # Local + T.evaluate(T.call_extern("ethosu_depthwise_conv2d", "int8", 4, 4, 4, 4, 0, 4, buffer9[0], 0, 0, 0, T.float32(0.0078125), 0, "NHCWB16", 64, 16, 1, "int8", 4, 4, 4, 4, 0, 4, buffer8[0], 0, 0, 0, T.float32(0.00372155), -128, "NHWC", 16, 4, 1, 1, 1, 1, 1, 1, 1, p5[0], 16, 0, p6[0], 48, 0, 0, 0, 0, "TANH", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + + + @tvm.script.ir_module + class ReferenceModule: + @T.prim_func + def main(buffer1: T.Buffer[(64,), "uint8"], + buffer2: T.Buffer[(96,), "uint8"], + buffer4: T.Buffer[(256,), "uint8"], + buffer5: T.Buffer[(64,), "uint8"], + buffer7: T.Buffer[(256,), "uint8"], + buffer8: T.Buffer[(64,), "uint8"], + buffer9: T.Buffer[(256,), "int8"], + ) -> None: + T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + # body + p1 = T.allocate([96], "uint8", "global") + p2 = T.allocate([64], "uint8", "global") + p3 = T.allocate([256], "int8", "local") + p7 = T.allocate([256], "int8", "local") + T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 96, p1[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer4[0], 256, p3[0], dtype="handle")) # Local + T.evaluate(T.call_extern("ethosu_copy", buffer5[0], 64, p2[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 4, 4, 4, 4, 0, 4, buffer1[0], 0, 0, 0, T.float32(0.00392081), -128, "NHWC", 16, 4, 1, "int8", 4, 4, 4, 4, 0, 4, buffer9[0], 0, 0, 0, T.float32(0.00839574), -128, "NHCWB16", 64, 16, 1, 1, 1, 1, 1, 1, 1, p1[0], 48, 0, p1[48], 48, 0, 0, 0, 0, "TANH", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer7[0], 256, p7[0], dtype="handle")) # Local + T.evaluate(T.call_extern("ethosu_depthwise_conv2d", "int8", 4, 4, 4, 4, 0, 4, buffer9[0], 0, 0, 0, T.float32(0.0078125), 0, "NHCWB16", 64, 16, 1, "int8", 4, 4, 4, 4, 0, 4, buffer8[0], 0, 0, 0, T.float32(0.00372155), -128, "NHWC", 16, 4, 1, 1, 1, 1, 1, 1, 1, p2[0], 16, 0, p2[16], 48, 0, 0, 0, 0, "TANH", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + # fmt: on + + const_dict = { + 1: np.array([1], dtype=np.uint8), + 2: np.array([2], dtype=np.uint8), + 3: np.array([3], dtype=np.uint8), + 4: np.array([4], dtype=np.uint8), + 5: np.array([5], dtype=np.uint8), + 6: np.array([6], dtype=np.uint8), + } + new_const_dict = { + 1: np.concatenate((const_dict[1], const_dict[2])), + 2: const_dict[3], + 3: np.concatenate((const_dict[4], const_dict[5])), + 4: const_dict[6], + } + test_mod, const_dict = MergeConstants(const_dict)(InputModule) + reference_mod = ReferenceModule + tvm.ir.assert_structural_equal(test_mod, reference_mod, True) + check_const_dictionaries(const_dict, new_const_dict) + + +def test_no_copies(): + # fmt: off + @tvm.script.ir_module + class InputModule: + @T.prim_func + def main() -> None: + # function attr dict + T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + placeholder = T.buffer_decl([20], "int8") + ethosu_write = T.buffer_decl([16], "int8") + # body + ethosu_write_4 = T.allocate([16], "int8", "global") + T.evaluate(T.call_extern("ethosu_binary_elementwise", "int8", 1, 4, 4, 1, 0, 4, placeholder[0], 0, 0, 0, T.float32(0.00783747), -128, "NHWC", 1, 4, 1, "int8", 1, 4, 1, 1, 0, 4, placeholder[16], 0, 0, 0, T.float32(0.00783747), -128, "NHWC", 1, 1, 1, "int8", 1, 4, 4, 1, 0, 4, ethosu_write_4[0], 0, 0, 0, T.float32(0.00783747), -128, "NHWC", 1, 4, 1, "MAX", 0, "CLIP", -128, 127, "TFL", 1, 4, 4, dtype="handle")) + T.evaluate(T.call_extern("ethosu_identity", "int8", 1, 4, 4, 1, 0, 4, ethosu_write_4[0], 0, 0, 0, T.float32(1), 0, "NHWC", 1, 4, 1, "int8", 1, 4, 4, 1, 0, 4, ethosu_write[0], 0, 0, 0, T.float32(1), 0, "NHWC", 1, 4, 1, "AVG", 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + + @tvm.script.ir_module + class ReferenceModule: + @T.prim_func + def main() -> None: + # function attr dict + T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + placeholder = T.buffer_decl([20], "int8") + ethosu_write = T.buffer_decl([16], "int8") + # body + ethosu_write_4 = T.allocate([16], "int8", "global") + T.evaluate(T.call_extern("ethosu_binary_elementwise", "int8", 1, 4, 4, 1, 0, 4, placeholder[0], 0, 0, 0, T.float32(0.00783747), -128, "NHWC", 1, 4, 1, "int8", 1, 4, 1, 1, 0, 4, placeholder[16], 0, 0, 0, T.float32(0.00783747), -128, "NHWC", 1, 1, 1, "int8", 1, 4, 4, 1, 0, 4, ethosu_write_4[0], 0, 0, 0, T.float32(0.00783747), -128, "NHWC", 1, 4, 1, "MAX", 0, "CLIP", -128, 127, "TFL", 1, 4, 4, dtype="handle")) + T.evaluate(T.call_extern("ethosu_identity", "int8", 1, 4, 4, 1, 0, 4, ethosu_write_4[0], 0, 0, 0, T.float32(1), 0, "NHWC", 1, 4, 1, "int8", 1, 4, 4, 1, 0, 4, ethosu_write[0], 0, 0, 0, T.float32(1), 0, "NHWC", 1, 4, 1, "AVG", 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + # fmt: on + + const_dict = {} + new_const_dict = {} + test_mod, const_dict = MergeConstants(const_dict)(InputModule) + reference_mod = ReferenceModule + tvm.ir.assert_structural_equal(test_mod, reference_mod, True) + check_const_dictionaries(const_dict, new_const_dict) + + +def test_copies_to_the_same_buffer(): + # fmt: off + @tvm.script.ir_module + class InputModule: + @T.prim_func + def main(buffer2: T.Buffer[(128,), "uint8"], buffer3: T.Buffer[(32,), "uint8"]) -> None: + # function attr dict + T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + buffer1 = T.buffer_decl([8192], "int8") + buffer10 = T.buffer_decl([2048], "int8") + # body + p1 = T.allocate([128], "uint8", "global") + p4 = T.allocate([32], "uint8", "global") + T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 128, p1[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer3[0], 32, p4[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, buffer1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, buffer10[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 128, 12, p4[0], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 128, p1[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer3[0], 32, p4[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, buffer1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, buffer10[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 128, 12, p4[0], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + + + @tvm.script.ir_module + class ReferenceModule: + @T.prim_func + def main(buffer2: T.Buffer[(160,), "uint8"]) -> None: + # function attr dict + T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + buffer1 = T.buffer_decl([8192], "int8") + buffer10 = T.buffer_decl([2048], "int8") + # body + p5 = T.allocate([160], "uint8", "global") + T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 160, p5[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, buffer1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, buffer10[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p5[0], 128, 12, p5[128], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 160, p5[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, buffer1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, buffer10[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p5[0], 128, 12, p5[128], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + # fmt: on + + const_dict = { + 0: np.array([0], dtype=np.uint8), + 1: np.array([1], dtype=np.uint8), + } + new_const_dict = {0: np.concatenate((const_dict[0], const_dict[1]))} + test_mod, const_dict = MergeConstants(const_dict)(InputModule) + reference_mod = ReferenceModule + tvm.ir.assert_structural_equal(test_mod, reference_mod, True) + check_const_dictionaries(const_dict, new_const_dict) + + +def test_read_from_the_same_buffer(): + # fmt: off + @tvm.script.ir_module + class InputModule: + @T.prim_func + def main(placeholder: T.Buffer[(8192,), "int8"], buffer1: T.Buffer[(368,), "uint8"], buffer2: T.Buffer[(96,), "uint8"], ethosu_write: T.Buffer[(2048,), "int8"]) -> None: + # function attr dict + T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + # buffer definition + T.preflattened_buffer(placeholder, [1, 16, 16, 32], dtype="int8", data=placeholder.data) + T.preflattened_buffer(ethosu_write, [1, 16, 16, 8], dtype="int8", data=ethosu_write.data) + # body + p1 = T.allocate([368], "uint8", "global") + p2 = T.allocate([96], "uint8", "global") + T.evaluate(T.call_extern("ethosu_copy", buffer1[0], 368, p1[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 96, p2[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 192, p1[192], 176, 12, p2[0], 48, p2[48], 48, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + __tvm_meta__ = None + + + @tvm.script.ir_module + class ReferenceModule: + @T.prim_func + def main(placeholder: T.Buffer[(8192,), "int8"], buffer1: T.Buffer[(464,), "uint8"], ethosu_write: T.Buffer[(2048,), "int8"]) -> None: + # function attr dict + T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + # body + p1 = T.allocate([464], "uint8", "global") + T.evaluate(T.call_extern("ethosu_copy", buffer1[0], 464, p1[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 192, p1[192], 176, 12, p1[368], 48, p1[416], 48, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + __tvm_meta__ = None + # fmt: on + + const_dict = { + 1: np.array([1], dtype=np.uint8), + 2: np.array([2], dtype=np.uint8), + } + new_const_dict = {1: np.concatenate((const_dict[1], const_dict[2]))} + test_mod, const_dict = MergeConstants(const_dict)(InputModule) + reference_mod = ReferenceModule + tvm.ir.assert_structural_equal(test_mod, reference_mod, True) + check_const_dictionaries(const_dict, new_const_dict) + + +def test_cycle_count(): + # fmt: off + @tvm.script.ir_module + class InputModule: + @T.prim_func + def main(buffer2: T.Buffer[(128,), "uint8"], buffer3: T.Buffer[(32,), "uint8"], buffer4: T.Buffer[(112,), "uint8"], buffer5: T.Buffer[(32,), "uint8"], buffer6: T.Buffer[(112,), "uint8"], buffer7: T.Buffer[(32,), "uint8"], buffer8: T.Buffer[(112,), "uint8"], buffer9: T.Buffer[(32,), "uint8"]) -> None: + # function attr dict + T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + v1a = T.var("int32") + v1b = T.var("int32") + v1c = T.var("int32") + v2a = T.var("int32") + v2b = T.var("int32") + v2c = T.var("int32") + v3a = T.var("int32") + v3b = T.var("int32") + v3c = T.var("int32") + v4a = T.var("int32") + v4b = T.var("int32") + v4c = T.var("int32") + buffer1 = T.buffer_decl([8192], "int8") + buffer10 = T.buffer_decl([2048], "int8") + # body + p1 = T.allocate([128], "uint8", "global") + p2 = T.allocate([112], "uint8", "global") + p3 = T.allocate([112], "uint8", "global") + p4 = T.allocate([32], "uint8", "global") + p5 = T.allocate([32], "uint8", "global") + p6 = T.allocate([32], "uint8", "global") + p7 = T.allocate([112], "uint8", "global") + p8 = T.allocate([3], "uint8", "global") + with T.attr(T.iter_var(v1a, None, "DataPar", ""), "pragma_compute_cycles_hint", 100): + T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 128, p1[0], dtype="handle")) + with T.attr(T.iter_var(v1b, None, "DataPar", ""), "pragma_compute_cycles_hint", 101): + T.evaluate(T.call_extern("ethosu_copy", buffer3[0], 32, p4[0], dtype="handle")) + with T.attr(T.iter_var(v2a, None, "DataPar", ""), "pragma_compute_cycles_hint", 102): + T.evaluate(T.call_extern("ethosu_copy", buffer4[0], 112, p2[0], dtype="handle")) + with T.attr(T.iter_var(v2b, None, "DataPar", ""), "pragma_compute_cycles_hint", 103): + T.evaluate(T.call_extern("ethosu_copy", buffer5[0], 32, p5[0], dtype="handle")) + with T.attr(T.iter_var(v1c, None, "DataPar", ""), "pragma_compute_cycles_hint", 300): + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, buffer1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, buffer10[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 128, 12, p4[0], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + with T.attr(T.iter_var(v3a, None, "DataPar", ""), "pragma_compute_cycles_hint", 104): + T.evaluate(T.call_extern("ethosu_copy", buffer6[0], 112, p3[0], dtype="handle")) + with T.attr(T.iter_var(v3b, None, "DataPar", ""), "pragma_compute_cycles_hint", 105): + T.evaluate(T.call_extern("ethosu_copy", buffer7[0], 32, p6[0], dtype="handle")) + with T.attr(T.iter_var(v2c, None, "DataPar", ""), "pragma_compute_cycles_hint", 301): + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, buffer1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, buffer10[2], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p2[0], 112, 12, p5[0], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + with T.attr(T.iter_var(v4a, None, "DataPar", ""), "pragma_compute_cycles_hint", 106): + T.evaluate(T.call_extern("ethosu_copy", buffer8[0], 112, p7[0], dtype="handle")) + with T.attr(T.iter_var(v4b, None, "DataPar", ""), "pragma_compute_cycles_hint", 107): + T.evaluate(T.call_extern("ethosu_copy", buffer9[0], 32, p8[0], dtype="handle")) + with T.attr(T.iter_var(v3c, None, "DataPar", ""), "pragma_compute_cycles_hint", 302): + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, buffer1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, buffer10[4], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p3[0], 112, 12, p6[0], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + with T.attr(T.iter_var(v4c, None, "DataPar", ""), "pragma_compute_cycles_hint", 303): + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, buffer1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, buffer10[6], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p7[0], 112, 12, p8[0], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + + + @tvm.script.ir_module + class ReferenceModule: + @T.prim_func + def main(buffer2: T.Buffer[(160,), "uint8"], buffer4: T.Buffer[(144,), "uint8"], buffer6: T.Buffer[(144,), "uint8"], buffer8: T.Buffer[(144,), "uint8"]) -> None: + # function attr dict + T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + v1a = T.var("int32") + v1c = T.var("int32") + v2a = T.var("int32") + v2c = T.var("int32") + v3a = T.var("int32") + v3c = T.var("int32") + v4a = T.var("int32") + v4c = T.var("int32") + buffer1 = T.buffer_decl([8192], "int8") + buffer10 = T.buffer_decl([2048], "int8") + # body + p4 = T.allocate([160], "uint8", "global") + p7 = T.allocate([144], "uint8", "global") + p10 = T.allocate([144], "uint8", "global") + p11 = T.allocate([144], "uint8", "global") + with T.attr(T.iter_var(v1a, None, "DataPar", ""), "pragma_compute_cycles_hint", 201): + T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 160, p4[0], dtype="handle")) + with T.attr(T.iter_var(v2a, None, "DataPar", ""), "pragma_compute_cycles_hint", 205): + T.evaluate(T.call_extern("ethosu_copy", buffer4[0], 144, p7[0], dtype="handle")) + with T.attr(T.iter_var(v1c, None, "DataPar", ""), "pragma_compute_cycles_hint", 300): + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, buffer1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, buffer10[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p4[0], 128, 12, p4[128], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + with T.attr(T.iter_var(v3a, None, "DataPar", ""), "pragma_compute_cycles_hint", 209): + T.evaluate(T.call_extern("ethosu_copy", buffer6[0], 144, p10[0], dtype="handle")) + with T.attr(T.iter_var(v2c, None, "DataPar", ""), "pragma_compute_cycles_hint", 301): + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, buffer1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, buffer10[2], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p7[0], 112, 12, p7[112], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + with T.attr(T.iter_var(v4a, None, "DataPar", ""), "pragma_compute_cycles_hint", 213): + T.evaluate(T.call_extern("ethosu_copy", buffer8[0], 144, p11[0], dtype="handle")) + with T.attr(T.iter_var(v3c, None, "DataPar", ""), "pragma_compute_cycles_hint", 302): + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, buffer1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, buffer10[4], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p10[0], 112, 12, p10[112], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + with T.attr(T.iter_var(v4c, None, "DataPar", ""), "pragma_compute_cycles_hint", 303): + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, buffer1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, buffer10[6], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p11[0], 112, 12, p11[112], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + # fmt: on + + const_dict = { + 0: np.array([0], dtype=np.uint8), + 1: np.array([1], dtype=np.uint8), + 2: np.array([2], dtype=np.uint8), + 3: np.array([3], dtype=np.uint8), + 4: np.array([4], dtype=np.uint8), + 5: np.array([5], dtype=np.uint8), + 6: np.array([6], dtype=np.uint8), + 7: np.array([7], dtype=np.uint8), + } + new_const_dict = { + 0: np.concatenate((const_dict[0], const_dict[1])), + 1: np.concatenate((const_dict[2], const_dict[3])), + 2: np.concatenate((const_dict[4], const_dict[5])), + 3: np.concatenate((const_dict[6], const_dict[7])), + } + test_mod, const_dict = MergeConstants(const_dict)(InputModule) + reference_mod = ReferenceModule + tvm.ir.assert_structural_equal(test_mod, reference_mod, True) + check_const_dictionaries(const_dict, new_const_dict) + + +def test_multiple_prim_funcs(): + # fmt: off + @tvm.script.ir_module + class InputModule: + @T.prim_func + def main(): + T.evaluate(0) + + @T.prim_func + def abc(): + T.evaluate(0) + # fmt: on + + err_rgx = ( + r"Expected a single primitive function called 'main'. " + r"Please run the MergeConstants pass in conjunction with the LowerToTIR\(\) pass." + ) + with pytest.raises(tvm.TVMError, match=err_rgx): + MergeConstants({})(InputModule) + + +def test_no_main_prim_func(): + # fmt: off + @tvm.script.ir_module + class InputModule: + @T.prim_func + def abs(): + T.evaluate(0) + # fmt: on + + err_rgx = ( + r"Expected a single primitive function called 'main'. " + r"Please run the MergeConstants pass in conjunction with the LowerToTIR\(\) pass." + ) + with pytest.raises(tvm.TVMError, match=err_rgx): + MergeConstants({})(InputModule) diff --git a/tests/python/contrib/test_ethosu/test_networks.py b/tests/python/contrib/test_ethosu/test_networks.py index 075565cd92a6..02643f6c1ded 100644 --- a/tests/python/contrib/test_ethosu/test_networks.py +++ b/tests/python/contrib/test_ethosu/test_networks.py @@ -44,13 +44,13 @@ @pytest.mark.parametrize( "accel_type, model_url, workspace_size", [ - ("ethos-u65-256", MOBILENET_V1_URL, 1892704), - ("ethos-u65-256", MOBILENET_V2_URL, 2257984), - ("ethos-u55-256", MOBILENET_V1_URL, 1892704), - ("ethos-u55-256", MOBILENET_V2_URL, 2257984), - ("ethos-u55-128", MOBILENET_V2_URL, 2257984), - ("ethos-u55-64", MOBILENET_V2_URL, 2257984), - ("ethos-u55-32", MOBILENET_V2_URL, 2258000), + ("ethos-u65-256", MOBILENET_V1_URL, 1793376), + ("ethos-u65-256", MOBILENET_V2_URL, 2218160), + ("ethos-u55-256", MOBILENET_V1_URL, 1793376), + ("ethos-u55-256", MOBILENET_V2_URL, 2218160), + ("ethos-u55-128", MOBILENET_V2_URL, 2218160), + ("ethos-u55-64", MOBILENET_V2_URL, 2218160), + ("ethos-u55-32", MOBILENET_V2_URL, 2218160), ], ) def test_networks_without_usmp(accel_type, model_url, workspace_size): diff --git a/tests/python/contrib/test_ethosu/test_remove_concatenates.py b/tests/python/contrib/test_ethosu/test_remove_concatenates.py index cc996e59412c..d2c759a0ae4d 100644 --- a/tests/python/contrib/test_ethosu/test_remove_concatenates.py +++ b/tests/python/contrib/test_ethosu/test_remove_concatenates.py @@ -41,9 +41,6 @@ def main(placeholder: T.Buffer[(1536,), "int8"], placeholder_1: T.Buffer[(1280,) buffer_5 = T.buffer_decl([160], "uint8") buffer_6 = T.buffer_decl([2992], "uint8") buffer_7 = T.buffer_decl([160], "uint8") - T.preflattened_buffer(placeholder, [1, 8, 12, 16], "int8", data=placeholder.data) - T.preflattened_buffer(placeholder_1, [1, 8, 10, 16], "int8", data=placeholder_1.data) - T.preflattened_buffer(T_concat, [1, 8, 32, 16], "int8", data=T_concat.data) # body T_concat_1 = T.allocate([2816], "int8", "global", annotations={"disable_lower_builtin":True}) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 10, 16, 8, 0, 10, placeholder_1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 160, 16, 1, "int8", 8, 10, 16, 8, 0, 10, T_concat_1[192], 0, 0, 0, T.float32(0.25), 14, "NHWC", 352, 16, 1, 3, 3, 1, 1, 1, 1, buffer[0], 2992, T.int8(-1), T.int8(-1), 12, buffer_1[0], 160, T.int8(-1), T.int8(-1), 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) diff --git a/tests/python/contrib/test_ethosu/test_replace_conv2d.py b/tests/python/contrib/test_ethosu/test_replace_conv2d.py index 63f9fc44c778..46a3c5a15bf5 100644 --- a/tests/python/contrib/test_ethosu/test_replace_conv2d.py +++ b/tests/python/contrib/test_ethosu/test_replace_conv2d.py @@ -373,8 +373,6 @@ def main(placeholder_5: T.Buffer[(192,), "int8"], ethosu_write_1: T.Buffer[(512, buffer_1 = T.buffer_decl([80], "uint8") buffer_2 = T.buffer_decl([320], "uint8") buffer_3 = T.buffer_decl([160], "uint8") - T.preflattened_buffer(placeholder_5, [1, 8, 8, 3], 'int8', data=placeholder_5.data) - T.preflattened_buffer(ethosu_write_1, [1, 8, 8, 8], 'int8', data=ethosu_write_1.data) # body ethosu_write_2 = T.allocate([1024], "int8", "global", annotations={"disable_lower_builtin": True}) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 4, 3, 8, 0, 4, placeholder_5[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 8, 4, 32, 8, 0, 4, ethosu_write_2[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 32, 1, 1, 1, 1, 1, 1, 1, buffer_3[0], 160, T.int8(-1), T.int8(-1), 12, buffer_2[0], 320, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) @@ -394,8 +392,6 @@ def main(placeholder_5: T.Buffer[(192,), "int8"], ethosu_write_1: T.Buffer[(512, buffer_1 = T.buffer_decl([320], "uint8") buffer_2 = T.buffer_decl([1312], "uint8") buffer_3 = T.buffer_decl([2608], "uint8") - T.preflattened_buffer(placeholder_5, [1, 8, 8, 3], 'int8', data=placeholder_5.data) - T.preflattened_buffer(ethosu_write_1, [1, 8, 8, 8], 'int8', data=ethosu_write_1.data) # body ethosu_write_2 = T.allocate([1536], "int8", "global", annotations={"disable_lower_builtin": True}) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 6, 8, 3, 6, 0, 8, placeholder_5[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 5, 8, 32, 5, 0, 8, ethosu_write_2[256], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 3, 3, 1, 1, 1, 1, buffer_2[0], 1312, T.int8(-1), T.int8(-1), 12, buffer_1[0], 320, T.int8(-1), T.int8(-1), 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) @@ -415,8 +411,6 @@ def main(placeholder_5: T.Buffer[(768,), "int8"], ethosu_write_1: T.Buffer[(640, buffer_1 = T.buffer_decl([80], "uint8") buffer_2 = T.buffer_decl([320], "uint8") buffer_3 = T.buffer_decl([880], "uint8") - T.preflattened_buffer(placeholder_5, [1, 16, 16, 3], 'int8', data=placeholder_5.data) - T.preflattened_buffer(ethosu_write_1, [1, 20, 4, 8], 'int8', data=ethosu_write_1.data) # body ethosu_write_2 = T.allocate([2560], "int8", "global", annotations={"disable_lower_builtin": True}) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 16, 3, 8, 0, 16, placeholder_5[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 48, 3, 1, "int8", 8, 8, 32, 8, 0, 8, ethosu_write_2[512], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 2, 3, 2, 1, 2, 1, buffer_3[0], 880, T.int8(-1), T.int8(-1), 12, buffer_2[0], 320, T.int8(-1), T.int8(-1), 2, 1, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) @@ -438,8 +432,6 @@ def main(placeholder_5: T.Buffer[(1024,), "int8"], ethosu_write_1: T.Buffer[(204 buffer_1 = T.buffer_decl([352], "uint8") buffer_2 = T.buffer_decl([272], "uint8") buffer_3 = T.buffer_decl([11040], "uint8") - T.preflattened_buffer(placeholder_5, [1, 8, 1, 8, 16], 'int8', data=placeholder_5.data) - T.preflattened_buffer(ethosu_write_1, [1, 8, 2, 8, 16], 'int8', data=ethosu_write_1.data) # body ethosu_write_2 = T.allocate([2304], "int8", "global", annotations={"disable_lower_builtin": True}) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 6, 8, 3, 6, 0, 8, placeholder_5[0], 0, 0, 0, T.float32(0.5), 10, "NHCWB16", 128, 16, 1, "int8", 5, 8, 35, 5, 0, 8, ethosu_write_2[384], 0, 0, 0, T.float32(0.25), 14, "NHCWB16", 384, 16, 128, 3, 3, 1, 1, 1, 1, buffer[0], 1456, T.int8(-1), T.int8(-1), 12, buffer_1[0], 352, T.int8(-1), T.int8(-1), 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) @@ -459,8 +451,6 @@ def main(placeholder: T.Buffer[(192,), "int8"], ethosu_write: T.Buffer[(8192,), buffer_1 = T.buffer_decl([320], "uint8") buffer_2 = T.buffer_decl([304], "uint8") buffer_3 = T.buffer_decl([80], "uint8") - T.preflattened_buffer(placeholder, [1, 8, 8, 3], 'int8', data=placeholder.data) - T.preflattened_buffer(ethosu_write, [1, 32, 32, 8], 'int8', data=ethosu_write.data) # body ethosu_write_1 = T.allocate([4096], "int8", "global", annotations={"disable_lower_builtin":True}) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 4, 8, 3, 4, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 8, 16, 32, 8, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 512, 32, 1, 1, 1, 1, 1, 1, 1, buffer[0], 160, T.int8(-1), T.int8(-1), 12, buffer_1[0], 320, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "ZEROS", 0, 0, 0, dtype="handle")) @@ -480,8 +470,6 @@ def main(placeholder: T.Buffer[(1024,), "int8"], ethosu_write: T.Buffer[(32768,) buffer_1 = T.buffer_decl([352], "uint8") buffer_2 = T.buffer_decl([11040], "uint8") buffer_3 = T.buffer_decl([272], "uint8") - T.preflattened_buffer(placeholder, [1, 8, 1, 8, 16], 'int8', data=placeholder.data) - T.preflattened_buffer(ethosu_write, [1, 32, 2, 32, 16], 'int8', data=ethosu_write.data) # body ethosu_write_1 = T.allocate([12288], "int8", "global", annotations={"disable_lower_builtin":True}) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 8, 3, 8, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHCWB16", 128, 16, 1, "int8", 16, 16, 35, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHCWB16", 768, 16, 256, 3, 3, 1, 1, 1, 1, buffer[0], 1456, T.int8(-1), T.int8(-1), 12, buffer_1[0], 352, T.int8(-1), T.int8(-1), 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NEAREST", 0, 0, 0, dtype="handle")) @@ -641,8 +629,6 @@ def main(placeholder_3: T.Buffer[(960,), "int8"], ethosu_write_1: T.Buffer[(1024 T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) buffer = T.buffer_decl([848], "uint8") buffer_1 = T.buffer_decl([160], "uint8") - T.preflattened_buffer(placeholder_3, [1, 10, 12, 8], 'int8', data=placeholder_3.data) - T.preflattened_buffer(ethosu_write_1, [1, 8, 8, 16], 'int8', data=ethosu_write_1.data) # body T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 8, 4, 8, 0, 8, placeholder_3[120], 0, 0, 0, T.float32(0.5), 10, "NHWC", 96, 8, 1, "int8", 8, 8, 16, 8, 0, 8, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 16, 1, 3, 3, 1, 1, 1, 1, buffer[0], 848, T.int8(-1), T.int8(-1), 12, buffer_1[0], 160, T.int8(-1), T.int8(-1), 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) __tvm_meta__ = None @@ -656,8 +642,6 @@ def main(placeholder_3: T.Buffer[(315,), "int8"], ethosu_write_1: T.Buffer[(240, T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) buffer = T.buffer_decl([160], "uint8") buffer_1 = T.buffer_decl([656], "uint8") - T.preflattened_buffer(placeholder_3, [1, 7, 9, 5], 'int8', data=placeholder_3.data) - T.preflattened_buffer(ethosu_write_1, [1, 3, 5, 16], 'int8', data=ethosu_write_1.data) # body T.evaluate(T.call_extern("ethosu_conv2d", "int8", 3, 5, 3, 3, 0, 5, placeholder_3[146], 0, 0, 0, T.float32(0.5), 10, "NHWC", 45, 5, 1, "int8", 3, 5, 16, 3, 0, 5, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 80, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 656, T.int8(-1), T.int8(-1), 12, buffer[0], 160, T.int8(-1), T.int8(-1), 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) __tvm_meta__ = None @@ -700,8 +684,6 @@ def main(placeholder_3: T.Buffer[(192,), "int8"], ethosu_write_1: T.Buffer[(768, T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) buffer = T.buffer_decl([160], "uint8") buffer_1 = T.buffer_decl([848], "uint8") - T.preflattened_buffer(placeholder_3, [4, 6, 8, 1], 'int8', data=placeholder_3.data) - T.preflattened_buffer(ethosu_write_1, [1, 8, 6, 16], 'int8', data=ethosu_write_1.data) # body T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, T.int8(-1), T.int8(-1), 12, buffer[0], 160, T.int8(-1), T.int8(-1), 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[72], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[384], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, T.int8(-1), T.int8(-1), 12, buffer[0], 160, T.int8(-1), T.int8(-1), 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) @@ -716,8 +698,6 @@ def main(placeholder_3: T.Buffer[(192,), "int8"], ethosu_write_1: T.Buffer[(768, T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) buffer = T.buffer_decl([160], "uint8") buffer_1 = T.buffer_decl([848], "uint8") - T.preflattened_buffer(placeholder_3, [1, 24, 8], 'int8', data=placeholder_3.data) - T.preflattened_buffer(ethosu_write_1, [1, 8, 6, 16], 'int8', data=ethosu_write_1.data) # body T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, T.int8(-1), T.int8(-1), 12, buffer[0], 160, T.int8(-1), T.int8(-1), 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[72], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[384], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, T.int8(-1), T.int8(-1), 12, buffer[0], 160, T.int8(-1), T.int8(-1), 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) @@ -732,8 +712,6 @@ def main(placeholder_3: T.Buffer[(192,), "int8"], ethosu_write_1: T.Buffer[(768, T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) buffer = T.buffer_decl([160], "uint8") buffer_1 = T.buffer_decl([848], "uint8") - T.preflattened_buffer(placeholder_3, [192, 1], 'int8', data=placeholder_3.data) - T.preflattened_buffer(ethosu_write_1, [1, 8, 6, 16], 'int8', data=ethosu_write_1.data) # body T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, T.int8(-1), T.int8(-1), 12, buffer[0], 160, T.int8(-1), T.int8(-1), 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[72], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[384], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, T.int8(-1), T.int8(-1), 12, buffer[0], 160, T.int8(-1), T.int8(-1), 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) @@ -748,8 +726,6 @@ def main(placeholder_3: T.Buffer[(192,), "int8"], ethosu_write_1: T.Buffer[(768, T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) buffer = T.buffer_decl([160], "uint8") buffer_1 = T.buffer_decl([848], "uint8") - T.preflattened_buffer(placeholder_3, [192], 'int8', data=placeholder_3.data) - T.preflattened_buffer(ethosu_write_1, [1, 8, 6, 16], 'int8', data=ethosu_write_1.data) # body T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, T.int8(-1), T.int8(-1), 12, buffer[0], 160, T.int8(-1), T.int8(-1), 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[72], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[384], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, T.int8(-1), T.int8(-1), 12, buffer[0], 160, T.int8(-1), T.int8(-1), 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) diff --git a/tests/python/contrib/test_ethosu/test_replace_copy.py b/tests/python/contrib/test_ethosu/test_replace_copy.py index 932df71d2402..6b97b38d80e6 100644 --- a/tests/python/contrib/test_ethosu/test_replace_copy.py +++ b/tests/python/contrib/test_ethosu/test_replace_copy.py @@ -34,16 +34,11 @@ class ReferenceModule: def main(placeholder_3: T.Buffer[(8192,), "int8"], ethosu_write_1: T.Buffer[(2048,), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - buffer = T.buffer_decl([80], "uint8") - buffer_1 = T.buffer_decl([304], "uint8") - T.preflattened_buffer(placeholder_3, [1, 16, 16, 32], dtype="int8", data=placeholder_3.data) - T.preflattened_buffer(ethosu_write_1, [1, 16, 16, 8], dtype="int8", data=ethosu_write_1.data) + buffer_1 = T.buffer_decl([384], "uint8") # body - placeholder_global = T.allocate([304], "uint8", "global", annotations={"disable_lower_builtin": True}) - placeholder_d_global = T.allocate([80], "uint8", "global", annotations={"disable_lower_builtin": True}) - T.evaluate(T.call_extern("ethosu_copy", buffer_1[0], 304, placeholder_global[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", buffer[0], 80, placeholder_d_global[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder_3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 8, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 304, T.int8(-1), T.int8(-1), 12, placeholder_d_global[0], 80, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + placeholder_global = T.allocate([384], "uint8", "global", annotations={"disable_lower_builtin": True}) + T.evaluate(T.call_extern("ethosu_copy", buffer_1[0], 384, placeholder_global[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder_3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 8, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 304, T.int8(-1), T.int8(-1), 12, placeholder_global[304], 80, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) __tvm_meta__ = None # fmt: on @@ -80,23 +75,15 @@ class WeightStream: def main(placeholder_5: T.Buffer[(8192,), "int8"], ethosu_write_1: T.Buffer[(4096,), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - buffer = T.buffer_decl([416], "uint8") - buffer_1 = T.buffer_decl([112], "uint8") - buffer_2 = T.buffer_decl([272], "uint8") - buffer_3 = T.buffer_decl([64], "uint8") - T.preflattened_buffer(placeholder_5, [1, 16, 16, 32], dtype="int8", data=placeholder_5.data) - T.preflattened_buffer(ethosu_write_1, [1, 16, 16, 16], dtype="int8", data=ethosu_write_1.data) + buffer = T.buffer_decl([528], "uint8") + buffer_2 = T.buffer_decl([336], "uint8") # body - placeholder_global_unrolled_iter_0 = T.allocate([416], "uint8", "global", annotations={"disable_lower_builtin": True}) - placeholder_d_global_unrolled_iter_0 = T.allocate([112], "uint8", "global", annotations={"disable_lower_builtin": True}) - placeholder_global_unrolled_iter_1 = T.allocate([272], "uint8", "global", annotations={"disable_lower_builtin": True}) - placeholder_d_global_unrolled_iter_1 = T.allocate([64], "uint8", "global", annotations={"disable_lower_builtin": True}) - T.evaluate(T.call_extern("ethosu_copy", buffer[0], 416, placeholder_global_unrolled_iter_0[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", buffer_1[0], 112, placeholder_d_global_unrolled_iter_0[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", buffer_2[0], 272, placeholder_global_unrolled_iter_1[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", buffer_3[0], 64, placeholder_d_global_unrolled_iter_1[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder_5[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 10, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, placeholder_global_unrolled_iter_0[0], 416, T.int8(-1), T.int8(-1), 12, placeholder_d_global_unrolled_iter_0[0], 112, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder_5[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 6, 16, 0, 16, ethosu_write_1[10], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, placeholder_global_unrolled_iter_1[0], 272, T.int8(-1), T.int8(-1), 12, placeholder_d_global_unrolled_iter_1[0], 64, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + placeholder_d_global = T.allocate([528], "uint8", "global", annotations={"disable_lower_builtin": True}) + placeholder_d_global_1 = T.allocate([336], "uint8", "global", annotations={"disable_lower_builtin": True}) + T.evaluate(T.call_extern("ethosu_copy", buffer[0], 528, placeholder_d_global[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_2[0], 336, placeholder_d_global_1[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder_5[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 10, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, placeholder_d_global[0], 416, T.int8(-1), T.int8(-1), 12, placeholder_d_global[416], 112, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder_5[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 6, 16, 0, 16, ethosu_write_1[10], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, placeholder_d_global_1[0], 272, T.int8(-1), T.int8(-1), 12, placeholder_d_global_1[272], 64, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) __tvm_meta__ = None # fmt: on diff --git a/tests/python/contrib/test_ethosu/test_scheduler.py b/tests/python/contrib/test_ethosu/test_scheduler.py index 4baea26e591e..ba050de2b473 100644 --- a/tests/python/contrib/test_ethosu/test_scheduler.py +++ b/tests/python/contrib/test_ethosu/test_scheduler.py @@ -182,24 +182,16 @@ class DiamondGraphTir: @T.prim_func def main(placeholder: T.Buffer[(301056,), "int8"], ethosu_write: T.Buffer[(75264,), "int8"]) -> None: T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - T.preflattened_buffer(placeholder, [1, 56, 56, 96], dtype='int8', data=placeholder.data) - T.preflattened_buffer(ethosu_write, [1, 56, 56, 24], dtype='int8', data=ethosu_write.data) - buffer1 = T.buffer_decl([2608], "uint8") - buffer2 = T.buffer_decl([240], "uint8") - buffer3 = T.buffer_decl([736], "uint8") - buffer4 = T.buffer_decl([240], "uint8") - p1 = T.allocate([2608], "uint8", "global", annotations={"disable_lower_builtin":True}) - p2 = T.allocate([240], "uint8", "global", annotations={"disable_lower_builtin":True}) - p3 = T.allocate([736], "uint8", "global", annotations={"disable_lower_builtin":True}) - p4 = T.allocate([240], "uint8", "global", annotations={"disable_lower_builtin":True}) + buffer1 = T.buffer_decl([2848], "uint8") + buffer3 = T.buffer_decl([976], "uint8") + p1 = T.allocate([2848], "uint8", "global", annotations={"disable_lower_builtin":True}) + p2 = T.allocate([976], "uint8", "global", annotations={"disable_lower_builtin":True}) p5 = T.allocate([75264], "int8", "global", annotations={"disable_lower_builtin":True}) p6 = T.allocate([75264], "int8", "global", annotations={"disable_lower_builtin":True}) - T.evaluate(T.call_extern("ethosu_copy", buffer1[0], 2608, p1[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 240, p2[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", buffer3[0], 736, p3[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", buffer4[0], 240, p4[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 56, 56, 96, 56, 0, 56, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 5376, 96, 1, "int8", 56, 56, 24, 56, 0, 56, p5[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 1344, 24, 1, 1, 1, 1, 1, 1, 1, p1[0], 2608, T.int8(-1), T.int8(-1), 12, p2[0], 240, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 56, 56, 24, 56, 0, 56, p5[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 1344, 24, 1, "int8", 56, 56, 24, 56, 0, 56, p6[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 1344, 24, 1, 1, 1, 1, 1, 1, 1, p3[0], 736, T.int8(-1), T.int8(-1), 12, p4[0], 240, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer1[0], 2848, p1[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer3[0], 976, p2[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 56, 56, 96, 56, 0, 56, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 5376, 96, 1, "int8", 56, 56, 24, 56, 0, 56, p5[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 1344, 24, 1, 1, 1, 1, 1, 1, 1, p1[0], 2608, T.int8(-1), T.int8(-1), 12, p1[2608], 240, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 56, 56, 24, 56, 0, 56, p5[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 1344, 24, 1, "int8", 56, 56, 24, 56, 0, 56, p6[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 1344, 24, 1, 1, 1, 1, 1, 1, 1, p2[0], 736, T.int8(-1), T.int8(-1), 12, p2[736], 240, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_binary_elementwise", "int8", 56, 56, 24, 56, 0, 56, p5[0], 0, 0, 0,T.float32(1), 0, "NHWC", 1344, 24, 1, "int8", 56, 56, 24, 56, 0, 56, p6[0], 0, 0, 0, T.float32(1), 0, "NHWC", 1344, 24, 1, "int8", 56, 56, 24, 56, 0, 56, ethosu_write[0], 0, 0, 0, T.float32(1), 0, "NHWC", 1344, 24, 1, "ADD", 0, "NONE", 0, 0, "TFL", 0, 0, 0, dtype="handle")) __tvm_meta__ = None # fmt: on