Decoupling AOT from graph memory planner
In this PR we decouple AOT from the Graph Memory Planner. Since AOT
expresses the runner function in TIR, we can drop the GMP in Relay
and use the Storage Rewrite pass to do memory planning on the runner
function. This also resolves the issue mentioned in apache#8062.

Change-Id: I6e33fadbf0462edf0366ee37e84ffde26123d3cb
Giuseppe Rossini committed May 20, 2021
1 parent c197531 commit bbf3ef5
Showing 5 changed files with 84 additions and 17 deletions.
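At the API level, the effect of this commit is that an AOT build no longer hands back its lowered functions as the "executor config". A minimal sketch, not part of the commit, assuming the `executor` keyword is threaded through `relay.build` as in `build_module.py` below; the tiny network is only a stand-in:

```python
import tvm
from tvm import relay

# A trivial model, used only to exercise the build flow.
x = relay.var("x", shape=(1, 8), dtype="float32")
mod = tvm.IRModule.from_expr(relay.Function([x], x + relay.const(1.0)))

with tvm.transform.PassContext(opt_level=3):
    factory = relay.build(mod, target="c", executor="aot")

# With executor="graph" the factory carries the graph JSON;
# for AOT there is no executor config anymore.
assert factory.get_executor_config() is None
```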
7 changes: 2 additions & 5 deletions python/tvm/relay/backend/executor_factory.py
@@ -85,23 +85,20 @@ class AOTExecutorFactoryModule(ExecutorFactoryModule):
     This holds a map of function names to their information
     """
 
-    def __init__(
-        self, ir_mod, target, lowered_funcs, libmod, libmod_name, params, function_metadata
-    ):
+    def __init__(self, ir_mod, target, libmod, libmod_name, params, function_metadata):
         self.ir_mod = ir_mod
         self.target = target
         self.lib = libmod
         self.libmod_name = libmod_name
         self.params = params
-        self.lowered_funcs = lowered_funcs
         self.iter_cnt = 0
         self.function_metadata = function_metadata
 
     def get_params(self):
         return self.params
 
     def get_executor_config(self):
-        return self.lowered_funcs
+        return None
 
     def get_lib(self):
         return self.lib
5 changes: 2 additions & 3 deletions python/tvm/relay/build_module.py
@@ -84,7 +84,6 @@ def __init__(self):
         self._set_params_func = self.mod["set_params"]
         self._get_params_func = self.mod["get_params"]
         self._get_function_metadata = self.mod["get_function_metadata"]
-        self._get_irmodule = self.mod["get_irmodule"]
 
     def build(self, mod, target=None, target_host=None, params=None, executor="graph"):
         """
@@ -152,7 +151,7 @@ def build(self, mod, target=None, target_host=None, params=None, executor="graph
         # Get artifacts
         mod = self.get_module()
         params = self.get_params()
-        executor_config = self.get_graph_json() if executor == "graph" else self._get_irmodule()
+        executor_config = self.get_graph_json() if executor == "graph" else None
 
         return executor_config, mod, params
 
@@ -337,7 +336,7 @@ def build(ir_mod, target=None, target_host=None, params=None, mod_name="default"
 
     if executor == "aot":
         executor_factory = _executor_factory.AOTExecutorFactoryModule(
-            ir_mod, target, executor_config, runtime_mod, mod_name, params, func_metadata
+            ir_mod, target, runtime_mod, mod_name, params, func_metadata
         )
     elif executor == "graph":
         executor_factory = _executor_factory.GraphExecutorFactoryModule(
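The tuple returned by the internal `BuildModule.build` changes accordingly: the first element is still the graph JSON for the graph executor, but for AOT it is now simply `None` rather than the IRModule previously fetched through `get_irmodule`. A hedged sketch of that internal contract, assuming `BuildModule` accepts a plain target string here:

```python
import tvm
from tvm import relay

x = relay.var("x", shape=(4,), dtype="float32")
mod = tvm.IRModule.from_expr(relay.Function([x], x * relay.const(2.0)))

bld_mod = relay.build_module.BuildModule()
executor_config, runtime_mod, params = bld_mod.build(mod, target="c", executor="aot")
assert executor_config is None  # previously self._get_irmodule()
```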
26 changes: 20 additions & 6 deletions src/relay/backend/aot_executor_codegen.cc
@@ -31,6 +31,7 @@
 #include <tvm/tir/expr.h>
 #include <tvm/tir/function.h>
 #include <tvm/tir/stmt.h>
+#include <tvm/tir/transform.h>
 
 #include <algorithm>
 #include <list>
@@ -48,7 +49,10 @@ using IntegerArray = Array<Integer>;
 using TargetsMap = std::unordered_map<int, Target>;
 using StorageMap = std::unordered_map<Expr, std::vector<std::vector<int>>, runtime::ObjectPtrHash,
                                       runtime::ObjectPtrEqual>;
-
+/**
+ * This is an on-demand allocator for AOT. A new temporary
+ * (storage allocator identifier) is allocated for each operation.
+ */
 class AOTOnDemandAllocator : public ExprVisitor {
  public:
   // run the visitor on a function.
@@ -678,6 +682,8 @@ class AOTExecutorCodegen : public ExprVisitor {
       // so we don't pay the price of allocation for every inference
       if (!allocated[sid]) {
         body = tir::Allocate(sids_table_[sid], DataType::Int(8), {size}, tir::const_true(), body);
+        body = tir::AttrStmt(sids_table_[sid], tir::attr::storage_scope, tir::StringImm("global"),
+                             body);
       }
       allocated[sid] = true;
     }
@@ -722,6 +728,7 @@ class AOTExecutorCodegen : public ExprVisitor {
 
   /*! \brief plan memory of device result */
   StorageMap storage_device_map_;
+  /*! \brief mapping sid -> tir::Var */
   std::unordered_map<int, te::Var> sids_table_;
   /*! \brief lowered funcs */
   std::unordered_map<std::string, IRModule> lowered_funcs_;
@@ -791,14 +798,21 @@ class AOTExecutorCodegen : public ExprVisitor {
     }
     ret.external_mods = compile_engine_->LowerExternalFunctions();
 
+    // Build the TIR IRModule
+    Map<GlobalVar, BaseFunc> symbol_map;
+    symbol_map.Set(GlobalVar(::tvm::runtime::symbol::tvm_run_func_prefix), prim_func);
+    IRModule mod_run(symbol_map);
+
+    // Apply storage rewrite pass to the runner function to do memory planning
+    auto storage_rewrite = tir::transform::StorageRewrite();
+    mod_run = storage_rewrite(mod_run);
+
+    // Update the lowered functions
     auto target_host_str = target_host_->str();
     if (ret.lowered_funcs.find(target_host_str) != ret.lowered_funcs.end()) {
-      ret.lowered_funcs[target_host_str]->Add(
-          GlobalVar(::tvm::runtime::symbol::tvm_run_func_prefix), prim_func);
+      ret.lowered_funcs[target_host_str]->Update(mod_run);
     } else {
-      Map<GlobalVar, BaseFunc> symbol_map;
-      symbol_map.Set(GlobalVar(::tvm::runtime::symbol::tvm_run_func_prefix), prim_func);
-      ret.lowered_funcs.Set(target_host_str, IRModule(symbol_map));
+      ret.lowered_funcs.Set(target_host_str, mod_run);
     }
     ret.function_metadata = std::move(function_metadata_);
     ret.metadata =
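The pass invoked above is the one exposed in Python as `tvm.tir.transform.StorageRewrite`, and the `AttrStmt` added in the codegen marks each allocation as "global"-scoped so the pass is able to plan it. A minimal sketch, not part of the commit, of running the same pass by hand on a small lowered function:

```python
import tvm
from tvm import te

# Two chained elementwise ops: the lowered function needs an
# intermediate buffer for B, which StorageRewrite can plan.
A = te.placeholder((1024,), name="A")
B = te.compute((1024,), lambda i: A[i] + 1.0, name="B")
C = te.compute((1024,), lambda i: B[i] * 2.0, name="C")
s = te.create_schedule(C.op)
mod = tvm.lower(s, [A, C], name="runner")

# Mirror the C++ call tir::transform::StorageRewrite() in
# aot_executor_codegen.cc by applying the pass to the module.
mod = tvm.tir.transform.StorageRewrite()(mod)
print(mod)  # allocations are merged/planned where lifetimes allow
```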
40 changes: 40 additions & 0 deletions tests/python/relay/aot/aot_test_utils.py
@@ -36,6 +36,46 @@
 from tvm.micro import export_model_library_format
 
 
+def convert_to_relay(
+    tflite_model_buf,
+    input_data,
+    input_node,
+):
+    """Convert a tflite model buffer into a Relay module"""
+
+    def convert_to_list(x):
+        if not isinstance(x, list):
+            x = [x]
+        return x
+
+    # TFLite.Model.Model has changed to TFLite.Model from 1.14 to 2.1
+    try:
+        import tflite.Model
+
+        tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model_buf, 0)
+    except AttributeError:
+        import tflite
+
+        tflite_model = tflite.Model.GetRootAsModel(tflite_model_buf, 0)
+    except ImportError:
+        raise ImportError("The tflite package must be installed")
+
+    input_data = convert_to_list(input_data)
+    input_node = convert_to_list(input_node)
+
+    shape_dict = {}
+    dtype_dict = {}
+    for i, e in enumerate(input_node):
+        shape_dict[e] = input_data[i].shape
+        dtype_dict[e] = input_data[i].dtype.name
+
+    mod, params = relay.frontend.from_tflite(
+        tflite_model, shape_dict=shape_dict, dtype_dict=dtype_dict
+    )
+    mod["main"] = relay.build_module.bind_params_by_name(mod["main"], params)
+    return mod, params
+
+
 def subprocess_with_stdout_and_log(cmd, cwd, logfile, stdout):
     """
     This method runs a process and logs the output to both a log file and stdout
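For reference, a hedged usage sketch of the new helper: `model.tflite` is a placeholder path to any TFLite flatbuffer, and the input name and shape are assumptions that must match that model (the new test below exercises it with a real quantized MobileNet):

```python
import numpy as np

# Placeholder path and input metadata -- adjust to the model at hand.
with open("model.tflite", "rb") as f:
    tflite_model_buf = f.read()

data = np.zeros((1, 224, 224, 3), dtype="uint8")
mod, params = convert_to_relay(tflite_model_buf, data, "input")
print(mod["main"])
```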
23 changes: 20 additions & 3 deletions tests/python/relay/aot/test_crt_aot.py
@@ -168,7 +168,6 @@ def test_nested_tuples(use_calculated_workspaces):
     x3 = x2 + relay.const(1.0)
     x4 = x3 + relay.const(1.0)
     out = relay.Tuple([x1, relay.Tuple([relay.Tuple([x2, x3]), x4])])
-
     func = relay.Function([x], out)
 
     x_data = np.random.uniform(size=(10,)).astype(np.float32)
@@ -181,7 +180,6 @@
 @pytest.mark.parametrize("use_calculated_workspaces", [True, False])
 def test_tuple_getitem(use_calculated_workspaces):
     func = relay.Function([], relay.TupleGetItem(relay.Tuple([relay.const(1), relay.const(2)]), 0))
-    print(func)
     output_list = generate_ref_data(func, {})
     input_list = []
     compile_and_run(func, input_list, output_list, use_calculated_workspaces)
@@ -241,7 +239,6 @@ def test_tuple_output(use_calculated_workspaces):
     c = relay.TupleGetItem(y, 2)
     out = relay.Tuple([a, b])
     func = relay.Function([x], out)
-    print(func)
     x_data = np.random.rand(6, 9).astype("float32")
     inputs = {"x": x_data}
     output_list = generate_ref_data(func, inputs)
@@ -367,5 +364,25 @@ def test_byoc_utvm(use_calculated_workspaces):
     compile_and_run(mod, input_list, output_list, use_calculated_workspaces)
 
 
+def test_quant_mobilenet_tfl():
+    import tvm.relay.testing.tf as tf_testing
+
+    tflite_model_file = tf_testing.get_workload_official(
+        "https://storage.googleapis.com/download.tensorflow.org/"
+        "models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224_quant.tgz",
+        "mobilenet_v1_1.0_224_quant.tflite",
+    )
+    with open(tflite_model_file, "rb") as f:
+        tflite_model_buf = f.read()
+    data_shape = (1, 224, 224, 3)
+    in_min, in_max = (0, 255)
+    data = np.random.randint(in_min, high=in_max, size=data_shape, dtype="uint8")
+    mod, params = convert_to_relay(tflite_model_buf, data, "input")
+    inputs = {"input": data}
+    output_list = generate_ref_data(mod, inputs, params)
+    input_list = [inputs["input"]]
+    compile_and_run(mod, input_list, output_list, True, params)
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
