From 7a381254afafd773cecdf6999e59decfbe8d9e8d Mon Sep 17 00:00:00 2001
From: "yunjing.lh"
Date: Mon, 27 Apr 2020 18:48:54 +0800
Subject: [PATCH] tflite debug

---
 include/tvm/tir/ir_pass.h                    | 286 +++++++++++++++++++
 src/printer/relay_graph_printer.cc           |  56 ++--
 src/relay/backend/build_module.cc            |   5 +-
 src/relay/backend/compile_engine.cc          |   4 -
 src/relay/backend/graph_runtime_codegen.cc   |   6 -
 src/relay/qnn/op/requantize.cc               |  10 +-
 src/relay/qnn/util.cc                        |  72 +++--
 tests/python/frontend/tflite/test_forward.py |  25 +-
 8 files changed, 384 insertions(+), 80 deletions(-)
 create mode 100644 include/tvm/tir/ir_pass.h

diff --git a/include/tvm/tir/ir_pass.h b/include/tvm/tir/ir_pass.h
new file mode 100644
index 0000000000000..d128d83b03e70
--- /dev/null
+++ b/include/tvm/tir/ir_pass.h
@@ -0,0 +1,286 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/tir/ir_pass.h
+ * \brief Collection of IR pass functions
+ *
+ * When the pass functions in this file are for Stmt,
+ * we can use PassFunction(Evaluate(expr)) to apply it to Expr
+ */
+#ifndef TVM_TIR_IR_PASS_H_
+#define TVM_TIR_IR_PASS_H_
+
+#include <tvm/te/schedule.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/buffer.h>
+#include <tvm/tir/function.h>
+
+#include <chrono>
+#include <iostream>
+#include <string>
+#include <unordered_map>
+#include <unordered_set>
+
+
+namespace tvm {
+namespace tir {
+
+class TaskTimer {
+ public:
+  explicit TaskTimer(std::string task) :
+    task_(task),
+    tp_(std::chrono::steady_clock::now()) {}
+
+  ~TaskTimer() {
+    auto tp2_ = std::chrono::steady_clock::now();
+    std::chrono::duration<double> elapsed_time = tp2_ - tp_;
+    if (elapsed_time.count() >= 10) {
+      // Only print non-negligible tasks (>= 10 seconds)
+      std::cout << task_ << " takes "
+                << elapsed_time.count() << "s" << std::endl;
+    }
+  }
+
+ private:
+  std::string task_;
+  std::chrono::time_point<std::chrono::steady_clock> tp_;
+};
+
+/*!
+ * \brief Simplify the expression.
+ * \param expr The expression to be simplified.
+ * \param vrange The range information about the variables.
+ * \return Simplified expression.
+ */
+TVM_DLL PrimExpr Simplify(PrimExpr expr, Map<Var, Range> vrange = Map<Var, Range>());
+
+/*!
+ * \brief Simplify the statement.
+ * \param stmt The statement to be simplified.
+ * \param vrange The range information about the variables.
+ * \return Simplified statement.
+ */
+Stmt Simplify(Stmt stmt, Map<Var, Range> vrange = Map<Var, Range>());
+
+/*!
+ * \brief Simplify by applying canonical form.
+ * \param stmt The statement to be canonically simplified.
+ * \param vrange The range information about the variables.
+ * \return Canonicalized statement.
+ */
+Stmt CanonicalSimplify(Stmt stmt,
+                       Map<Var, Range> vrange = Map<Var, Range>());
+
+/*!
+ * \brief Simplify by applying canonical form.
+ * \param expr The expression to be canonically simplified.
+ * \param vrange The range information about the variables.
+ * \return Canonicalized expression.
+ */
+TVM_DLL PrimExpr CanonicalSimplify(PrimExpr expr,
+                                   Map<Var, Range> vrange = Map<Var, Range>());
+
+/*!
+ * \brief Verifies whether the IR stmt or Expr is in SSA form.
+ *  That is: each VarExpr is defined and assigned once (in Let/For).
+ *
+ * \param ir The root of the IR DAG.
+ * \return Whether IR is in SSA form.
+ * \note All the passes in this file use SSA form and output SSA form.
+ */
+TVM_DLL bool VerifySSA(const Stmt& ir);
+
+/*!
+ * \brief Whether the expression has side effects.
+ * \return Whether the expression has side effects.
+ */
+TVM_DLL bool HasSideEffect(const PrimExpr& e);
+
+/*!
+ * \brief Whether expression e uses variable v.
+ * \param e The expression to be checked.
+ * \param v The variable.
+ * \return Whether e uses v.
+ */
+bool ExprUseVar(const PrimExpr& e, const Var& v);
+
+/*!
+ * \brief Whether expression e uses any variable in the variable set.
+ * \param e The expression to be checked.
+ * \param vset The variable set.
+ * \return Whether e uses any variable in vset.
+ */
+bool ExprUseVar(const PrimExpr& e, const std::unordered_set<const VarNode*>& vset);
+
+/*!
+ * \brief Convert an IR node to SSA form.
+ * \param stmt The source statement to be converted.
+ * \return The converted form.
+ */
+TVM_DLL Stmt ConvertSSA(Stmt stmt);
+
+/*!
+ * \brief Substitute the var specified in key->var to be value.
+ * \param stmt The source statement to be substituted.
+ * \param value_map The map of new values.
+ * \return The converted form.
+ */
+Stmt Substitute(Stmt stmt,
+                const std::unordered_map<const VarNode*, PrimExpr>& value_map);
+
+/*!
+ * \brief Substitute the var specified in key->var to be value.
+ * \param expr The source expression to be substituted.
+ * \param value_map The map of new values.
+ * \return The converted expression.
+ */
+PrimExpr Substitute(PrimExpr expr,
+                    const std::unordered_map<const VarNode*, PrimExpr>& value_map);
+
+/*!
+ * \brief Substitute the var specified in key->var to be value.
+ * \param stmt The source statement to be substituted.
+ * \param value_map The map of new values.
+ * \return The converted form.
+ */
+Stmt Substitute(Stmt stmt, const Map<Var, PrimExpr>& value_map);
+
+/*!
+ * \brief Substitute the var specified in key->var to be value.
+ * \param expr The source expression to be substituted.
+ * \param value_map The map of new values.
+ * \return The converted expression.
+ */
+PrimExpr Substitute(PrimExpr expr, const Map<Var, PrimExpr>& value_map);
+
+/*!
+ * \brief Inline all calls of f in stmt.
+ *
+ * \param stmt The statement to apply inline optimization.
+ * \param f The function reference to be inlined.
+ * \param args The argument variables of the function.
+ * \param body The definition body of the function.
+ * \return The result stmt.
+ *
+ * \note All the passes in this file use SSA form and output SSA form.
+ */
+Stmt Inline(Stmt stmt,
+            FunctionRef f,
+            Array<Var> args,
+            PrimExpr body);
+
+/*!
+ * \brief Flatten the multi-dimensional read/write
+ *  to single-dimensional Load/Store.
+ *
+ * \param stmt The stmt to be transformed.
+ * \param extern_buffer Map that specifies the external
+ *  buffer assignment of inputs and outputs.
+ * \param cache_line_size The size of the CPU cache line.
+ * \param create_bound_attribute Whether to create bound attributes.
+ * \return Transformed stmt.
+ */
+Stmt StorageFlatten(Stmt stmt,
+                    Map<te::Tensor, Buffer> extern_buffer,
+                    int cache_line_size,
+                    bool create_bound_attribute = false);
+
+/*!
+ * \brief Try to modify the AST to support TensorCore.
+ *
+ * \param stmt The stmt to be transformed.
+ * \param schedule The original schedule.
+ * \param extern_buffer Map that specifies the external
+ *  buffer assignment of inputs and outputs.
+ * \return Transformed stmt.
+ */
+Stmt RewriteForTensorCore(Stmt stmt,
+                          te::Schedule schedule,
+                          Map<te::Tensor, Buffer> extern_buffer);
+
+/*!
+ * \brief Verify if there is any argument bound to a compact buffer.
+ *
+ * \param stmt The stmt to be verified.
+ * \return true if there is any buffer_bind_scope attribute found,
+ *         otherwise, false.
+ */
+bool VerifyCompactBuffer(Stmt stmt);
+
+/*!
+ * \brief Inject prefetch instructions into stmt.
+ * \param stmt The statement to be transformed.
+ * \return Transformed stmt.
+ */
+Stmt InjectPrefetch(Stmt stmt);
+
+/*!
+ * \brief Decorate the stmt with a device scope. This is helpful for
+ *  hardware accelerators without thread blocks.
+ *
+ * \param stmt The stmt to be transformed.
+ * \return Transformed stmt.
+ */
+Stmt DecorateDeviceScope(Stmt stmt);
+
+/*!
+ * \brief Loop-invariant code motion which locates and hoists if statements.
+ * \param stmt The stmt to do if-statement hoisting on.
+ * \return Transformed stmt.
+ */
+Stmt HoistIfThenElse(Stmt stmt);
+
+/*!
+ * \brief Rewrite the pointer content type of arguments,
+ *  as well as Alloc internal to the function, to use
+ *  the most frequently accessed type for load/store,
+ *  to avoid pointer casting in the backend when possible.
+ *
+ * \note Implemented in storage_rewrite.cc.
+ * \param f The function to be transformed.
+ * \return Transformed function.
+ */
+PrimFunc PointerValueTypeRewrite(PrimFunc f);
+
+/*!
+ * \brief Verify the correctness of a GPU code.
+ *  It will check whether the amount of memory used or the number of threads
+ *  in a block exceeds the limits.
+ * \param stmt The statement to be checked.
+ * \param constraints The dict to specify constraints to check.
+ *  Possible keys are
+ *
+ *  "max_local_memory_per_block": Total amount of local memory per block (in bytes).
+ *  "max_shared_memory_per_block": Total amount of shared memory per block (in bytes).
+ *  "max_threads_per_block": Maximum number of threads per block.
+ *  "max_thread_x": Maximum length of threadIdx.x.
+ *  "max_thread_y": Maximum length of threadIdx.y.
+ *  "max_thread_z": Maximum length of threadIdx.z.
+ *
+ *  If one key is missing in this argument, the pass won't check for that item.
+ * \return valid Whether it is a valid GPU code.
+ *
+ */
+bool VerifyGPUCode(Stmt stmt,
+                   Map<std::string, PrimExpr> constraints);
+
+}  // namespace tir
+}  // namespace tvm
+#endif  // TVM_TIR_IR_PASS_H_
diff --git a/src/printer/relay_graph_printer.cc b/src/printer/relay_graph_printer.cc
index b1d4e0778cfbe..946159789cfc3 100644
--- a/src/printer/relay_graph_printer.cc
+++ b/src/printer/relay_graph_printer.cc
@@ -43,16 +43,16 @@ class RelayGraphPrinter :
  public:
   explicit RelayGraphPrinter() {}
 
-  // Doc TempVar(int n) {
-  //   Doc doc;
-  //   return doc << "%" << n;
-  // }
-
   Doc AllocTemp() {
     Doc doc;
     return doc << temp_var_counter_++;
   }
 
+  Doc AllocTempFunc() {
+    Doc doc;
+    return doc << temp_func_counter_++;
+  }
+
   Doc PrintFinal(const ObjectRef& node) {
     if (node->IsInstance() &&
         !node->IsInstance()) {
@@ -63,11 +63,11 @@ class RelayGraphPrinter :
     }
     Doc graph_prop;
-    graph_prop << Doc::NewLine() << "node [shape=box;fontsize=9]";
+    graph_prop << "node [shape=box;fontsize=9]" << Doc::NewLine()
+               << PrintScope(node) << Doc::NewLine();
     Doc doc;
-    doc << "digraph relay {";
+    doc << "digraph relay {" << Doc::NewLine();
     doc << Doc::Indent(2, graph_prop);
-    doc << PrintScope(node);
     doc << "}";
     return doc;
   }
@@ -105,7 +105,6 @@ class RelayGraphPrinter :
     } else if (node.as()) {
       std::cout << "print pattern node ignored" << std::endl;
     } else if (node.as<IRModuleNode>()) {
-      std::cout << "printing ir module" << std::endl;
       return PrintMod(Downcast<IRModule>(node));
     } else {
       std::cout << "printing raw node" << std::endl;
@@ -275,9 +274,7 @@ class RelayGraphPrinter :
     return Doc::Text(unique_prefix);
   }
-
-
-  Doc PrintFunc(const Doc& prefix, const relay::Function& fn) {
+  Doc PrintFunc(Doc prefix, const relay::Function& fn) {
     Doc param_doc;
     std::vector<Doc> params;
 
   }
 
     Doc out_node;
-    out_node << prefix << "_out";
+    Doc fn_node;
+    if (prefix.str() == "fn ") {
+      fn_node << "fn_" << AllocTempFunc();
+      out_node << fn_node << "_out";
+    } else {
+      fn_node << prefix;
+      out_node << prefix << "_out";
+    }
+
     memo_[fn] = out_node;
     param_doc << ConstructParamNode(out_node, Print(fn->ret_type));
-    // param_doc << prefix << "_out [fontcolor=red;fontsize=8;label=\""
-    //           << prefix << "_out\\n"
-    //           << Print(fn->ret_type) << "\"]" << Doc::NewLine();
 
     Doc body_doc;
     body_doc << Doc::NewLine()
-             << "style=filled;" << Doc::NewLine()
-             << "color=lightgrey;" << Doc::NewLine()
-             << "label = \"" << prefix << "\"" << Doc::NewLine()
+             << "label = \"" << fn_node << "\"" << Doc::NewLine()
              << param_doc;
     body_doc << PrintScope(fn->body);
 
     Doc doc;
-    doc << Doc::NewLine()
-        << "subgraph cluster_func_" << prefix << " {"
+    doc << "subgraph cluster_func_" << fn_node << " {"
         << Doc::Indent(2, body_doc)
-        << memo_[fn->body] << "->" << prefix << "_out" << Doc::NewLine()
-        << "}";
+        << memo_[fn->body] << "->" << out_node << Doc::NewLine()
+        << "}" << Doc::NewLine();
     return doc;
   }
 
-  Doc PrintFunc(const Doc& prefix, const BaseFunc& base_func) {
+  Doc PrintFunc(Doc prefix, const BaseFunc& base_func) {
     if (auto* n = base_func.as<relay::FunctionNode>()) {
       return PrintFunc(prefix, GetRef<relay::Function>(n));
     } else if (auto* n = base_func.as()) {
@@ -420,12 +419,18 @@ class RelayGraphPrinter :
   }
 
   Doc VisitExpr_(const CallNode* op) final {
+    Doc doc;
     Expr op_expr = GetRef<Expr>(op);
     Doc op_doc;
     const auto* cons_node = op->op.as<ConstructorNode>();
+    const auto* func_node = op->op.as<FunctionNode>();
     if (cons_node) {
       op_doc << Doc::Text(cons_node->name_hint);
+    } else if (func_node) {
+      doc << VisitExpr(op->op);
+      op_doc << memo_[op->op];
+      memo_[op_expr] = op_doc;
     } else {
       op_doc << VisitExpr(op->op);
     }
@@ -437,7 +442,6 @@ class RelayGraphPrinter :
     label_doc << Print(op->checked_type());
     memo_label_[op_expr] = label_doc;
 
-    Doc doc;
     doc << ConstructOpNode(memo_[op_expr], op_doc, memo_label_[op_expr]);
     // visit args first so they are lifted before the op
     // this places op closer to its call site
@@ -761,7 +765,7 @@ class RelayGraphPrinter :
   /*! \brief meta data context */
   TextMetaDataContext meta_;
   /*! \brief counter of temporary variable */
-  size_t temp_var_counter_{0};
+  size_t temp_var_counter_{0}, temp_func_counter_{0};
   /*! \brief whether the printer is currently in an ADT definition */
   bool in_adt_def_;
   /*! \brief arena for dependency graph */
diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc
index 1db832739000e..6dd84ecbd3e9b 100644
--- a/src/relay/backend/build_module.cc
+++ b/src/relay/backend/build_module.cc
@@ -438,18 +438,17 @@ class RelayBuildModule : public runtime::ModuleNode {
   void BuildRelay(
       IRModule relay_module,
       const std::unordered_map<std::string, runtime::NDArray>& params) {
+    tvm::tir::TaskTimer tt("relay.build");
     // Relay IRModule -> IRModule optimizations.
     relay_module = Optimize(relay_module, targets_, params);
     // Get the updated function.
     auto func = Downcast<Function>(relay_module->Lookup("main"));
-    std::cout << "Optimize done" << std::endl;
+    std::cout << AsText(func, false) << std::endl;
 
     // Generate code for the updated function.
     graph_codegen_ = std::unique_ptr<GraphCodegen>(new GraphCodegen());
     graph_codegen_->Init(nullptr, targets_);
-    std::cout << "GraphCodegen init done" << std::endl;
     graph_codegen_->Codegen(func);
-    std::cout << "GraphCodegen done" << std::endl;
 
     ret_.graph_json = graph_codegen_->GetJSON();
     ret_.params = graph_codegen_->GetParams();
diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc
index 0eb408d3f482d..ce0a314f265b2 100644
--- a/src/relay/backend/compile_engine.cc
+++ b/src/relay/backend/compile_engine.cc
@@ -708,20 +708,16 @@ class CompileEngineImpl : public CompileEngineNode {
     for (te::Tensor arg : cache_node->outputs) {
       all_args.push_back(arg);
     }
-    std::cout << "lowering " << cache_node->func_name << std::endl;
     // lower the function
     if (const auto* f = runtime::Registry::Get("relay.backend.lower")) {
-      std::cout << "using global lower" << std::endl;
       cache_node->funcs = (*f)(
           cfunc->schedule, all_args, cache_node->func_name, key->source_func);
     } else {
-      std::cout << "using default lower" << std::endl;
       tvm::BuildConfig bcfg = BuildConfig::Create();
       std::unordered_map<te::Tensor, tir::Buffer> binds;
       cache_node->funcs = tvm::lower(cfunc->schedule, all_args, cache_node->func_name, binds, bcfg);
     }
-    std::cout << "lower done" << std::endl;
     value->cached_func = CachedFunc(cache_node);
     return value;
   }
diff --git a/src/relay/backend/graph_runtime_codegen.cc b/src/relay/backend/graph_runtime_codegen.cc
index c332a06e33453..2a0d7981a17d4 100644
--- a/src/relay/backend/graph_runtime_codegen.cc
+++ b/src/relay/backend/graph_runtime_codegen.cc
@@ -357,7 +357,6 @@ class GraphRuntimeCodegen : public backend::MemoizedExprTranslator<std::vector<GraphNodeRef>>
   std::vector<GraphNodeRef> VisitExpr_(const CallNode* op) override {
     Expr expr = GetRef<Expr>(op);
-    std::cout << "adding call node \n" << AsText(op->op, false) << std::endl;
     Function func;
     if (op->op.as<OpNode>()) {
       LOG(FATAL) << "Operators should be transformed away; try applying"
@@ -390,7 +389,6 @@ class GraphRuntimeCodegen : public backend::MemoizedExprTranslator<std::vector<GraphNodeRef>>
     auto call_dev_type = device_type[0]->value;
     // Normal Relay Function
-    std::cout << "target size: " << targets_.size() << std::endl;
     if (targets_.size() == 1) {
       // homogeneous execution.
       const auto& it = targets_.begin();
@@ -409,16 +407,12 @@ class GraphRuntimeCodegen : public backend::MemoizedExprTranslator<std::vector<GraphNodeRef>>
     if (!lowered_funcs_.count(target->str())) {
       lowered_funcs_[target->str()] = IRModule::Empty();
     }
     lowered_funcs_[target->str()]->Update(lowered_func->funcs);
-    std::cout << "added call node, funcs: " << lowered_funcs_.size() << std::endl;
     return GraphAddCallNode(op, _GetUniqueName(lowered_func->func_name),
                             lowered_func->func_name);
diff --git a/src/relay/qnn/op/requantize.cc b/src/relay/qnn/op/requantize.cc
index 8a48c8e079d4a..235d4a11a8893 100644
--- a/src/relay/qnn/op/requantize.cc
+++ b/src/relay/qnn/op/requantize.cc
@@ -138,7 +138,15 @@ Expr RequantizeLower(const Expr& input_tensor, const Expr& input_scale,
                      const Expr& input_zero_point, const Expr& output_scale,
                      const Expr& output_zero_point, const RequantizeAttrs* param,
                      const Array<IndexExpr>& input_shape, const DataType& out_dtype) {
-  auto tensor = Cast(input_tensor, DataType::Int(32));
+  bool input_is_int32 = false;
+  if ((input_tensor.as() || input_tensor.as()) &&
+      input_tensor->checked_type_.defined()) {
+    auto tensor_type = input_tensor->checked_type().as<TensorTypeNode>();
+    if (tensor_type && tensor_type->dtype == DataType::Int(32))
+      input_is_int32 = true;
+  }
+  auto tensor = input_is_int32 ? input_tensor : Cast(input_tensor, DataType::Int(32));
+  // auto tensor = Cast(input_tensor, DataType::Int(32));
   // 1) Subtract the input_zero_point
   auto zero_scalar = MakeConstantScalar(DataType::Int(32), 0);
   if (!IsEqualScalar(input_zero_point, zero_scalar)) {
diff --git a/src/relay/qnn/util.cc b/src/relay/qnn/util.cc
index f61b0b5f0b056..d5c2951b6d018 100644
--- a/src/relay/qnn/util.cc
+++ b/src/relay/qnn/util.cc
@@ -37,30 +37,38 @@ namespace qnn {
 
 Expr SaturatingRoundingDoublingHigh32(const Expr& input_tensor,
                                       const Expr& multiplier_expr,
                                       const Expr& scaled_tensor,
-                                      const Array<IndexExpr>& input_shape) {
+                                      const Array<IndexExpr>& input_shape,
+                                      bool possible_to_overflow = true) {
   DataType hp_dtype = DataType::Int(64);
+  DataType lp_dtype = DataType::Int(32);
-  int32_t pos_nudge_value = (1ll << 30);
-  int32_t neg_nudge_value = 1 - (1ll << 30);
+  int64_t pos_nudge_value = (1ll << 30);
+  int64_t neg_nudge_value = 1 - (1ll << 30);
   auto pos_nudge = MakeConstantScalar(hp_dtype, pos_nudge_value);
   auto neg_nudge = MakeConstantScalar(hp_dtype, neg_nudge_value);
   auto pos_nudge_t = Full(pos_nudge, input_shape, hp_dtype);
   auto neg_nudge_t = Full(neg_nudge, input_shape, hp_dtype);
-  auto int32_min = MakeConstantScalar(
-      hp_dtype, std::numeric_limits<int32_t>::min());
-  auto int32_max = MakeConstantScalar(
-      hp_dtype, std::numeric_limits<int32_t>::max());
-  auto int32_min_t = Full(int32_min, input_shape, hp_dtype);
-  auto int32_max_t = Full(int32_max, input_shape, hp_dtype);
 
   auto dividend = MakeConstantScalar(hp_dtype, 1ll << 31);
   auto zero_t = Zeros(input_shape, hp_dtype);
   auto nudged_tensor_t =
       Add(scaled_tensor, Where(GreaterEqual(scaled_tensor, zero_t), pos_nudge_t, neg_nudge_t));
-  auto high32_t = Divide(nudged_tensor_t, dividend);
-  auto overflow_t = LogicalAnd(Equal(input_tensor, int32_min_t),
-                               Equal(multiplier_expr, int32_min_t));
-  return Where(overflow_t, int32_max_t, high32_t);
+  auto high32_t = Cast(Divide(nudged_tensor_t, dividend), lp_dtype);
+
+  if (possible_to_overflow) {
+    auto int32_min = MakeConstantScalar(
+        lp_dtype, std::numeric_limits<int32_t>::min());
+    auto int32_max = MakeConstantScalar(
+        lp_dtype, std::numeric_limits<int32_t>::max());
+    auto int32_max_t = Full(int32_max, input_shape, lp_dtype);
+    auto int32_min_t = Full(int32_min, input_shape, lp_dtype);
+
+    auto overflow_t = LogicalAnd(Equal(input_tensor, int32_min_t),
+                                 Equal(multiplier_expr, int32_min_t));
+    return Where(overflow_t, int32_max_t, high32_t);
+  } else {
+    return high32_t;
+  }
 }
 
 /*
@@ -114,6 +122,7 @@ Expr FixedPointMultiply(Expr tensor, double multiplier, const Array<IndexExpr>&
   // Choose high precision datatype to be int64. This is for avoiding overflow
   // in multiplication of two int32 values.
   DataType hp_dtype = DataType::Int(64);
+  DataType lp_dtype = DataType::Int(32);
   tensor = Cast(tensor, hp_dtype);
 
   // 1) Calculating the integer multiplier and integer shift
@@ -147,14 +156,14 @@ Expr FixedPointMultiply(Expr tensor, double multiplier, const Array<IndexExpr>&
   // rounding scheme, which calculates a rounder tensor according to the sign
   // of values in the tensor to be rounded.
   auto nearest_rounding_scalar =
-      [&](const Expr& input_tensor, int right_shift) -> Expr {
+      [&](const Expr& input_tensor, int right_shift, DataType dtype) -> Expr {
     int64_t pos_rounding_value = (1ll << (right_shift - 1));
-    auto pos_rounder = MakeConstantScalar(hp_dtype, pos_rounding_value);
-    auto neg_rounder = MakeConstantScalar(hp_dtype, pos_rounding_value - 1);
-    auto pos_rounder_t = Full(pos_rounder, input_shape, hp_dtype);
-    auto neg_rounder_t = Full(neg_rounder, input_shape, hp_dtype);
+    auto pos_rounder = MakeConstantScalar(dtype, pos_rounding_value);
+    auto neg_rounder = MakeConstantScalar(dtype, pos_rounding_value - 1);
+    auto pos_rounder_t = Full(pos_rounder, input_shape, dtype);
+    auto neg_rounder_t = Full(neg_rounder, input_shape, dtype);
 
-    auto zero_t = Zeros(input_shape, hp_dtype);
+    auto zero_t = Zeros(input_shape, dtype);
     return Where(
         GreaterEqual(input_tensor, zero_t), pos_rounder_t, neg_rounder_t);
   };
@@ -163,18 +172,25 @@ Expr FixedPointMultiply(Expr tensor, double multiplier, const Array<IndexExpr>&
   if (rounding == "UPWARD") {
     round_scalar = MakeConstantScalar(hp_dtype, pos_rounding_value);
   } else if (rounding == "TONEAREST") {
-    round_scalar = nearest_rounding_scalar(scaled_tensor, total_right_shift);
+    round_scalar = nearest_rounding_scalar(scaled_tensor, total_right_shift, hp_dtype);
   } else if (rounding == "TFLITE") {
     auto scalar_t = Full(scalar, input_shape, hp_dtype);
+    bool possible_to_overflow =
+        fixed_point_multiplier == std::numeric_limits<int32_t>::min();
     auto high32_t = SaturatingRoundingDoublingHigh32(
-        tensor, scalar_t, scaled_tensor, input_shape);
-
-    auto zero_t = Zeros(input_shape, hp_dtype);
-    round_scalar = nearest_rounding_scalar(high32_t, right_shift);
-    scaled_tensor = right_shift > 0 ? Add(high32_t, round_scalar) : high32_t;
-    auto rshift_expr = MakeConstantScalar(hp_dtype, right_shift);
-    auto right_shift_t = Full(rshift_expr, input_shape, hp_dtype);
-    return Cast(RightShift(scaled_tensor, right_shift_t), DataType::Int(32));
+        tensor, scalar_t, scaled_tensor, input_shape, possible_to_overflow);
+
+    if (right_shift <= 0) {
+      scaled_tensor = high32_t;
+    } else {
+      auto zero_t = Zeros(input_shape, lp_dtype);
+      round_scalar = nearest_rounding_scalar(high32_t, right_shift, lp_dtype);
+      scaled_tensor = Add(high32_t, round_scalar);
+      auto rshift_expr = MakeConstantScalar(lp_dtype, right_shift);
+      // auto right_shift_t = Full(rshift_expr, input_shape, lp_dtype);
+      scaled_tensor = RightShift(scaled_tensor, rshift_expr);
+    }
+    return scaled_tensor;
   } else {
     LOG(FATAL) << "Rounding mode " << rounding << " not supported.";
   }
diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py
index a6037670f9322..ec2487ee2dbb2 100644
--- a/tests/python/frontend/tflite/test_forward.py
+++ b/tests/python/frontend/tflite/test_forward.py
@@ -97,8 +97,6 @@ def run_tvm_graph(tflite_model_buf, input_data, input_node, num_output=1, target
                                              shape_dict=shape_dict,
                                              dtype_dict=dtype_dict,
                                              out_idx=out_idx)
-    print(mod.asgraph())
-    print(mod.astext())
 
     with relay.build_config(opt_level=opt_level):
         graph, lib, params = relay.build(mod, target, params=params)
@@ -1818,7 +1816,7 @@ def test_forward_qnn_mobilenet_v2_net():
         tflite_model_buf = f.read()
 
     data = list()
-    bs = 1
+    bs = 100
     for i in range(bs):
         np.random.seed(i)
         # np.random.seed(43)
@@ -1827,18 +1825,21 @@ def test_forward_qnn_mobilenet_v2_net():
     tflite_tensor_idx = {0:1, 8:101, 9:93, 16:112, 20:122, 28:156,
                         27:141, 30:159, 31:151, 35:161, 57:74, 62:6, -1:-1}
-    out_op_idx = 0
+    out_op_idx = -1
     tflite_output = run_tflite_graph(tflite_model_buf, data, out_idx=tflite_tensor_idx[out_op_idx])
     tvm_output = run_tvm_graph(tflite_model_buf, data, 'input', out_idx=out_op_idx)
 
     for i in range(bs):
-        print("verify ", i)
         tflite_predictions = np.squeeze(tflite_output[i])
         tvm_predictions = np.squeeze(tvm_output[i])
-        # res = tvm_predictions.flatten()
-        #for i, x in enumerate(res):
-        #    print(i, x)
-        tvm.testing.assert_allclose(tvm_predictions, tflite_predictions,
-                                    rtol=0, atol=0)
+        tvm_res = tvm_predictions.flatten()
+        tflite_res = tflite_predictions.flatten()
+        print("verify ", i, "diff", np.sum(np.abs(tvm_res - tflite_res)))
+
+        # for i, x in enumerate(tvm_res):
+        #     if (x != tflite_res[i]):
+        #         print(i, x, tflite_res[i])
+        # tvm.testing.assert_allclose(tvm_predictions.astype('int32'), tflite_predictions.astype('int32'),
+        #                             rtol=0, atol=0)
 
 #######################################################################
 # Mobilenet V3 Quantized
@@ -1915,8 +1916,8 @@ def test_forward_mediapipe_hand_landmark():
 # ----
 if __name__ == '__main__':
     #_test_forward_elemwise_quantized(_test_add)
-    test_forward_qnn_mobilenet_v2_net()
-    exit()
+    # test_forward_qnn_mobilenet_v2_net()
+    # exit()
 
     # BatchToSpaceND
     test_forward_batch_to_space_nd()
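
A small standalone NumPy sketch can help sanity-check the "TFLITE" rounding path touched in src/relay/qnn/util.cc above: it mirrors the gemmlowp-style saturating rounding doubling high multiply and the same to-nearest right shift the patch applies afterwards. The helper names and the example multiplier/shift values below are illustrative only; they are not part of the patch or of TVM's API.

import numpy as np

INT32_MIN = np.iinfo(np.int32).min
INT32_MAX = np.iinfo(np.int32).max

def saturating_rounding_doubling_high_mul(a, b):
    # High 32 bits of 2*a*b with rounding; saturates when a == b == INT32_MIN.
    a = np.asarray(a, dtype=np.int64)
    b = np.int64(b)
    overflow = np.logical_and(a == INT32_MIN, b == INT32_MIN)
    ab = a * b
    nudge = np.where(ab >= 0, 1 << 30, 1 - (1 << 30)).astype(np.int64)
    v = ab + nudge
    high32 = (np.sign(v) * (np.abs(v) >> 31)).astype(np.int32)  # truncating divide by 2^31
    return np.where(overflow, INT32_MAX, high32).astype(np.int32)

def rounding_right_shift(x, shift):
    # To-nearest right shift, following the nearest_rounding_scalar scheme in the patch.
    if shift <= 0:
        return np.asarray(x, dtype=np.int32)
    x = np.asarray(x, dtype=np.int64)
    rounder = np.where(x >= 0, 1 << (shift - 1), (1 << (shift - 1)) - 1)
    return ((x + rounder) >> shift).astype(np.int32)

# Illustrative requantization multiplier of roughly 0.123 expressed as
# (fixed_point_multiplier, right_shift); values chosen for this example only.
fixed_point_multiplier, right_shift = 2113929216, 3
x = np.array([-50, 0, 37, 1000], dtype=np.int32)
print(rounding_right_shift(saturating_rounding_doubling_high_mul(x, fixed_point_multiplier), right_shift))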