diff --git a/python/tvm/relay/backend/executor_factory.py b/python/tvm/relay/backend/executor_factory.py
index 2189b6ad3579d..4ed76f4b6366c 100644
--- a/python/tvm/relay/backend/executor_factory.py
+++ b/python/tvm/relay/backend/executor_factory.py
@@ -85,15 +85,12 @@ class AOTExecutorFactoryModule(ExecutorFactoryModule):
         This holds a map function names to their information
     """
 
-    def __init__(
-        self, ir_mod, target, lowered_funcs, libmod, libmod_name, params, function_metadata
-    ):
+    def __init__(self, ir_mod, target, libmod, libmod_name, params, function_metadata):
         self.ir_mod = ir_mod
         self.target = target
         self.lib = libmod
         self.libmod_name = libmod_name
         self.params = params
-        self.lowered_funcs = lowered_funcs
         self.iter_cnt = 0
         self.function_metadata = function_metadata
 
@@ -101,7 +98,7 @@ def get_params(self):
         return self.params
 
     def get_executor_config(self):
-        return self.lowered_funcs
+        return None
 
     def get_lib(self):
         return self.lib
diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py
index 1e3d853cb2ec7..e134eeeefd09b 100644
--- a/python/tvm/relay/build_module.py
+++ b/python/tvm/relay/build_module.py
@@ -84,7 +84,6 @@ def __init__(self):
         self._set_params_func = self.mod["set_params"]
         self._get_params_func = self.mod["get_params"]
         self._get_function_metadata = self.mod["get_function_metadata"]
-        self._get_irmodule = self.mod["get_irmodule"]
 
     def build(self, mod, target=None, target_host=None, params=None, executor="graph"):
         """
@@ -152,7 +151,7 @@ def build(self, mod, target=None, target_host=None, params=None, executor="graph
         # Get artifacts
         mod = self.get_module()
         params = self.get_params()
-        executor_config = self.get_graph_json() if executor == "graph" else self._get_irmodule()
+        executor_config = self.get_graph_json() if executor == "graph" else None
 
         return executor_config, mod, params
 
@@ -337,7 +336,7 @@ def build(ir_mod, target=None, target_host=None, params=None, mod_name="default"
 
     if executor == "aot":
         executor_factory = _executor_factory.AOTExecutorFactoryModule(
-            ir_mod, target, executor_config, runtime_mod, mod_name, params, func_metadata
+            ir_mod, target, runtime_mod, mod_name, params, func_metadata
         )
     elif executor == "graph":
         executor_factory = _executor_factory.GraphExecutorFactoryModule(
diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc
index 1d4f0ad05b8dd..07e2d7f978594 100644
--- a/src/relay/backend/aot_executor_codegen.cc
+++ b/src/relay/backend/aot_executor_codegen.cc
@@ -31,6 +31,7 @@
 #include
 #include
 #include
+#include <tvm/tir/transform.h>
 
 #include
 #include
@@ -48,7 +49,10 @@ using IntegerArray = Array<Integer>;
 using TargetsMap = std::unordered_map<int, Target>;
 using StorageMap = std::unordered_map<Expr, std::vector<std::vector<int64_t>>,
                                       runtime::ObjectPtrHash, runtime::ObjectPtrEqual>;
-
+/**
+ * This is an on-demand allocator for AOT. A new temporary
+ * (storage allocator identifier) is allocated for each operation.
+ */
 class AOTOnDemandAllocator : public ExprVisitor {
  public:
   // run the visitor on a function.
@@ -678,6 +682,8 @@ class AOTExecutorCodegen : public ExprVisitor {
       // so we don't pay the price of allocation for every inference
       if (!allocated[sid]) {
         body = tir::Allocate(sids_table_[sid], DataType::Int(8), {size}, tir::const_true(), body);
+        body = tir::AttrStmt(sids_table_[sid], tir::attr::storage_scope, tir::StringImm("global"),
+                             body);
       }
       allocated[sid] = true;
     }
@@ -722,6 +728,7 @@ class AOTExecutorCodegen : public ExprVisitor {
 
   /*! \brief plan memory of device result */
   StorageMap storage_device_map_;
+  /*! \brief mapping sid -> tir::Var */
   std::unordered_map<int, te::Var> sids_table_;
   /*! \brief lowered funcs */
   std::unordered_map<std::string, IRModule> lowered_funcs_;
@@ -791,14 +798,21 @@ class AOTExecutorCodegen : public ExprVisitor {
     }
     ret.external_mods = compile_engine_->LowerExternalFunctions();
 
+    // Build the TIR IRModule
+    Map<GlobalVar, BaseFunc> symbol_map;
+    symbol_map.Set(GlobalVar(::tvm::runtime::symbol::tvm_run_func_prefix), prim_func);
+    IRModule mod_run(symbol_map);
+
+    // Apply the storage rewrite pass to the runner function to do memory planning
+    auto storage_rewrite = tir::transform::StorageRewrite();
+    mod_run = storage_rewrite(mod_run);
+
+    // Update the lowered functions
     auto target_host_str = target_host_->str();
     if (ret.lowered_funcs.find(target_host_str) != ret.lowered_funcs.end()) {
-      ret.lowered_funcs[target_host_str]->Add(
-          GlobalVar(::tvm::runtime::symbol::tvm_run_func_prefix), prim_func);
+      ret.lowered_funcs[target_host_str]->Update(mod_run);
     } else {
-      Map<GlobalVar, BaseFunc> symbol_map;
-      symbol_map.Set(GlobalVar(::tvm::runtime::symbol::tvm_run_func_prefix), prim_func);
-      ret.lowered_funcs.Set(target_host_str, IRModule(symbol_map));
+      ret.lowered_funcs.Set(target_host_str, mod_run);
     }
     ret.function_metadata = std::move(function_metadata_);
     ret.metadata =
diff --git a/tests/python/relay/aot/aot_test_utils.py b/tests/python/relay/aot/aot_test_utils.py
index 8c7aefe70d091..e7358828c8f81 100644
--- a/tests/python/relay/aot/aot_test_utils.py
+++ b/tests/python/relay/aot/aot_test_utils.py
@@ -36,6 +36,46 @@
 from tvm.micro import export_model_library_format
 
 
+def convert_to_relay(
+    tflite_model_buf,
+    input_data,
+    input_node,
+):
+    """Convert a tflite model buffer into a Relay module."""
+
+    def convert_to_list(x):
+        if not isinstance(x, list):
+            x = [x]
+        return x
+
+    # TFLite.Model.Model has changed to TFLite.Model from 1.14 to 2.1
+    try:
+        import tflite.Model
+
+        tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model_buf, 0)
+    except AttributeError:
+        import tflite
+
+        tflite_model = tflite.Model.GetRootAsModel(tflite_model_buf, 0)
+    except ImportError:
+        raise ImportError("The tflite package must be installed")
+
+    input_data = convert_to_list(input_data)
+    input_node = convert_to_list(input_node)
+
+    shape_dict = {}
+    dtype_dict = {}
+    for i, e in enumerate(input_node):
+        shape_dict[e] = input_data[i].shape
+        dtype_dict[e] = input_data[i].dtype.name
+
+    mod, params = relay.frontend.from_tflite(
+        tflite_model, shape_dict=shape_dict, dtype_dict=dtype_dict
+    )
+    mod["main"] = relay.build_module.bind_params_by_name(mod["main"], params)
+    return mod, params
+
+
 def subprocess_with_stdout_and_log(cmd, cwd, logfile, stdout):
     """ This method runs a process and logs the output to both a log file and stdout
 
diff --git a/tests/python/relay/aot/test_crt_aot.py b/tests/python/relay/aot/test_crt_aot.py
index 501127463f5d9..57641e093257f 100644
--- a/tests/python/relay/aot/test_crt_aot.py
+++ b/tests/python/relay/aot/test_crt_aot.py
@@ -168,7 +168,6 @@ def test_nested_tuples(use_calculated_workspaces):
     x3 = x2 + relay.const(1.0)
     x4 = x3 + relay.const(1.0)
     out = relay.Tuple([x1, relay.Tuple([relay.Tuple([x2, x3]), x4])])
-
     func = relay.Function([x], out)
 
     x_data = np.random.uniform(size=(10,)).astype(np.float32)
@@ -181,7 +180,6 @@
 @pytest.mark.parametrize("use_calculated_workspaces", [True, False])
 def test_tuple_getitem(use_calculated_workspaces):
     func = relay.Function([], relay.TupleGetItem(relay.Tuple([relay.const(1), relay.const(2)]), 0))
-    print(func)
     output_list = generate_ref_data(func, {})
     input_list = []
     compile_and_run(func, input_list, output_list, use_calculated_workspaces)
@@ -241,7 +239,6 @@ def test_tuple_output(use_calculated_workspaces):
     c = relay.TupleGetItem(y, 2)
     out = relay.Tuple([a, b])
     func = relay.Function([x], out)
-    print(func)
    x_data = np.random.rand(6, 9).astype("float32")
     inputs = {"x": x_data}
     output_list = generate_ref_data(func, inputs)
@@ -367,5 +364,25 @@ def test_byoc_utvm(use_calculated_workspaces):
     compile_and_run(mod, input_list, output_list, use_calculated_workspaces)
 
 
+def test_quant_mobilenet_tfl():
+    import tvm.relay.testing.tf as tf_testing
+
+    tflite_model_file = tf_testing.get_workload_official(
+        "https://storage.googleapis.com/download.tensorflow.org/"
+        "models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224_quant.tgz",
+        "mobilenet_v1_1.0_224_quant.tflite",
+    )
+    with open(tflite_model_file, "rb") as f:
+        tflite_model_buf = f.read()
+    data_shape = (1, 224, 224, 3)
+    in_min, in_max = (0, 255)
+    data = np.random.randint(in_min, high=in_max, size=data_shape, dtype="uint8")
+    mod, params = convert_to_relay(tflite_model_buf, data, "input")
+    inputs = {"input": data}
+    output_list = generate_ref_data(mod, inputs, params)
+    input_list = [inputs["input"]]
+    compile_and_run(mod, input_list, output_list, True, params)
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
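
Reviewer note, not part of the patch: the core of the C++ change is that the AOT runner function is now wrapped in its own IRModule and memory-planned with the stock StorageRewrite TIR pass before being merged into ret.lowered_funcs via Update(); the new "storage_scope" AttrStmt marks each tir.Allocate as global so the pass can see and fold those buffers. Since planning now happens entirely inside codegen, the Python side no longer needs a lowered-function map, which is why _get_irmodule disappears and get_executor_config() returns None for AOT. The sketch below reproduces the planning step from Python on a toy workload. It is an illustration only: the te expressions and the name "run_sketch" are invented for this note, it drives the same tvm.tir.transform.StorageRewrite pass rather than the AOT codegen itself, and tvm.lower may already have run StorageRewrite once internally (re-applying it is harmless).

import tvm
from tvm import te

# Two chained elementwise ops, so the lowered PrimFunc carries an
# intermediate buffer that memory planning can rewrite.
A = te.placeholder((16,), name="A")
B = te.compute((16,), lambda i: A[i] + 1.0, name="B")
C = te.compute((16,), lambda i: B[i] * 2.0, name="C")
s = te.create_schedule(C.op)

# tvm.lower yields an IRModule wrapping a single PrimFunc, analogous to the
# mod_run module the codegen now builds around tvm_run_func_prefix.
mod = tvm.lower(s, [A, C], name="run_sketch")

# The pass the codegen applies: StorageRewrite plans and merges the
# function's internal tir.Allocate nodes.
planned = tvm.tir.transform.StorageRewrite()(mod)
print(planned)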