From 51e841c1912a551542f2986ecc02103edade39aa Mon Sep 17 00:00:00 2001 From: Andrew Reusch Date: Tue, 18 May 2021 09:47:12 -0700 Subject: [PATCH 01/16] rename _update_target and document its function --- python/tvm/relay/build_module.py | 26 +++++++++++++++++++++----- 1 file changed, 21 insertions(+), 5 deletions(-) diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py index e134eeeefd09..5eaa6517f247 100644 --- a/python/tvm/relay/build_module.py +++ b/python/tvm/relay/build_module.py @@ -39,7 +39,23 @@ from .backend.vm import VMExecutor -def _update_target(target): +def build_target_by_device_type_map(target): + """Build a map from DLDevice device_type to a Target used with that device. + + At runtime, TVM assigns target code to DLDevices by determining a device_type for each Target. + This function handles this process at compile time and, as a side effect, validates that exactly + one target maps to one device_type. + + Parameters + ---------- + target : Target or str or dict + If a Target or str: assumes that exactly one device type is present in the model. + If a dict: keys are tvm.ndarray.device, values are the targets used for each device. + + Returns + ------- + + """ target = target if target else Target.current() if target is None: raise ValueError("Target is not set in env or passed as argument.") @@ -126,7 +142,7 @@ def build(self, mod, target=None, target_host=None, params=None, executor="graph params : dict The parameters of the final graph. """ - target = _update_target(target) + target = build_target_by_device_type_map(target) target, target_host = Target.check_and_update_host_consist( target, target_host, target_is_dict_key=False ) @@ -179,7 +195,7 @@ def optimize(self, mod, target=None, params=None): params : dict The parameters of the final graph. """ - target = _update_target(target) + target = build_target_by_device_type_map(target) # Setup the params. if params: @@ -307,7 +323,7 @@ def build(ir_mod, target=None, target_host=None, params=None, mod_name="default" "instead of deprecated parameter mod (tvm.relay.function.Function)", DeprecationWarning, ) - target = _update_target(target) + target = build_target_by_device_type_map(target) if isinstance(target_host, (str, Target)): target_host = Target(target_host) elif target_host: @@ -386,7 +402,7 @@ def optimize(mod, target=None, params=None): DeprecationWarning, ) - target = _update_target(target) + target = build_target_by_device_type_map(target) # If current dispatch context is fallback context (the default root context), # then load pre-tuned parameters from TopHub From 76ef467526bcebf873ec0207c0479f5b3d10080e Mon Sep 17 00:00:00 2001 From: Andrew Reusch Date: Tue, 18 May 2021 09:48:21 -0700 Subject: [PATCH 02/16] make tvm.build return OperatorModule to return multiple outputs --- python/tvm/driver/build_module.py | 28 ++++++++++++++++++++++++---- 1 file changed, 24 insertions(+), 4 deletions(-) diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index a4df63f225b2..8682cc27e34b 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -24,6 +24,7 @@ import tvm.tir +from tvm.runtime import Module from tvm.runtime import ndarray from tvm.ir import container from tvm.ir import CallingConv @@ -372,12 +373,31 @@ def build( create_csource_crt_metadata_module = tvm._ffi.get_global_func( "runtime.CreateCSourceCrtMetadataModule" ) - return create_csource_crt_metadata_module([rt_mod_host], target_host) + to_return = create_csource_crt_metadata_module([rt_mod_host], target_host) - if target_host.kind.name == "llvm": + elif target_host.kind.name == "llvm": create_llvm_crt_metadata_module = tvm._ffi.get_global_func( "runtime.CreateLLVMCrtMetadataModule" ) - return create_llvm_crt_metadata_module([rt_mod_host], target_host) + to_return = create_llvm_crt_metadata_module([rt_mod_host], target_host) + else: + to_return = rt_mod_host + + return OperatorModule.from_module(to_return, ir_module_by_target=target_input_mod, name=name) + + +class OperatorModule(Module): + + @classmethod + def from_module(cls, mod, **kwargs): + # NOTE(areusch): It is generally unsafe to continue using `mod` from this point forward. + # If an exception occurs in cls.__init__, handle will be deleted. For this reason, + # set mod.handle to None. + handle = mod.handle + mod.handle = None + return cls(handle, **kwargs) - return rt_mod_host + def __init__(self, handle, ir_module_by_target=None, name=None): + super(OperatorModule, self).__init__(handle) + self.ir_module_by_target = ir_module_by_target + self.name = name From 66126f4fab82947b9d34b059d63a41817a5199c3 Mon Sep 17 00:00:00 2001 From: Andrew Reusch Date: Tue, 18 May 2021 09:49:04 -0700 Subject: [PATCH 03/16] allow retrieving the var names used in TIR repr --- src/printer/model_library_format_printer.cc | 51 +++++++++++++++++++++ src/printer/text_printer.h | 12 +++++ src/printer/tir_text_printer.cc | 10 ++++ 3 files changed, 73 insertions(+) create mode 100644 src/printer/model_library_format_printer.cc diff --git a/src/printer/model_library_format_printer.cc b/src/printer/model_library_format_printer.cc new file mode 100644 index 000000000000..4834ec6aaefa --- /dev/null +++ b/src/printer/model_library_format_printer.cc @@ -0,0 +1,51 @@ +#include +#include +#include +#include "text_printer.h" + +namespace tvm { +namespace printer { + +class ModelLibraryFormatPrinter : public ::tvm::runtime::ModuleNode { + public: + ModelLibraryFormatPrinter(bool show_meta_data, const runtime::TypedPackedFunc& annotate, bool show_warning) : + text_printer_{show_meta_data, annotate, show_warning} {} + + const char* type_key() const override { + return "model_library_format_printer"; + } + + std::string Print(const ObjectRef& node) { + Doc doc; + doc << text_printer_.PrintFinal(node); + return doc.str(); + } + + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) override { + if (name == "print") { + return TypedPackedFunc([sptr_to_self, this](ObjectRef node) { return Print(node); }); + } else if (name == "get_var_name") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + ICHECK_EQ(args.size(), 1) << "usage: get_var_name(Var v)"; + + std::string var_name; + if (text_printer_.GetVarName(args[0], &var_name)) { + *rv = var_name; + } + }); + } else { + return PackedFunc(); + } + } + + private: + TextPrinter text_printer_; +}; + +TVM_REGISTER_GLOBAL("tir.ModelLibraryFormatPrinter").set_body_typed( + [](bool show_meta_data, const runtime::TypedPackedFunc& annotate, bool show_warning) { + return ObjectRef(make_object(show_meta_data, annotate, show_warning)); + }); + +} // namespace printer +} // namespace tvm diff --git a/src/printer/text_printer.h b/src/printer/text_printer.h index 7a529cc0b914..6400c0904a5b 100644 --- a/src/printer/text_printer.h +++ b/src/printer/text_printer.h @@ -34,6 +34,7 @@ #include #include #include +#include #include #include @@ -256,6 +257,13 @@ class TIRTextPrinter : public StmtFunctor, /*! \brief Print the node */ Doc Print(const ObjectRef& node); + /*! \brief Place into `s` the name used in the preceding Print call for `v`. + * \param v Var instance to check. Must point to a VarNode visited by Print. + * \param s String to receive the name. + * \return true when a name re-mapping was found. + */ + bool GetVarName(::tvm::tir::Var v, std::string* s); + private: /*! \brief whether show meta data */ bool show_meta_; @@ -394,6 +402,10 @@ class TextPrinter { /*! \brief TIR Text Printer */ tir::TIRTextPrinter tir_text_printer_; + bool GetVarName(::tvm::tir::Var v, std::string* s) { + return tir_text_printer_.GetVarName(v, s); + } + Doc PrintFinal(const ObjectRef& node) { Doc doc; if (node->IsInstance()) { diff --git a/src/printer/tir_text_printer.cc b/src/printer/tir_text_printer.cc index 04c5ea1cdf99..0fefb0515e49 100644 --- a/src/printer/tir_text_printer.cc +++ b/src/printer/tir_text_printer.cc @@ -734,5 +734,15 @@ Doc TIRTextPrinter::PrintBody(const Stmt& body, bool indent) { return doc; } +bool TIRTextPrinter::GetVarName(Var v, std::string* s) { + auto it = memo_var_.find(v); + if (it == memo_var_.end()) { + return false; + } + + *s = it->second.str(); + return true; +} + } // namespace tir } // namespace tvm From 010f8ff309815e8a60dfc76b743490d9406a9223 Mon Sep 17 00:00:00 2001 From: Andrew Reusch Date: Tue, 18 May 2021 09:49:46 -0700 Subject: [PATCH 04/16] add Operator Model Library Format and test --- python/tvm/micro/model_library_format.py | 193 +++++++++++++++--- .../test_micro_model_library_format.py | 47 +++++ 2 files changed, 209 insertions(+), 31 deletions(-) diff --git a/python/tvm/micro/model_library_format.py b/python/tvm/micro/model_library_format.py index 1cc3adf9ae07..30ab014eb7f6 100644 --- a/python/tvm/micro/model_library_format.py +++ b/python/tvm/micro/model_library_format.py @@ -20,12 +20,19 @@ import datetime import json import os +import pathlib import re import tarfile +import typing +from .._ffi import get_global_func from ..contrib import utils +from ..driver import build_module +from ..runtime import ndarray as _nd +from ..relay import build_module as relay_build_module from ..relay.backend import executor_factory from ..relay import param_dict +from ..tir import expr # This should be kept identical to runtime::symbol::tvm_module_main MAIN_FUNC_NAME_STR = "__tvm_main__" @@ -203,67 +210,191 @@ def _build_function_memory_map(function_metadata): return ret -def export_model_library_format(mod: executor_factory.ExecutorFactoryModule, file_name): - """Export the build artifact in Model Library Format. +def _make_tar(source_dir, tar_file_path): + """Build a tar file from source_dir.""" + with tarfile.open(tar_file_path, "w") as tar_f: + def reset(tarinfo): + tarinfo.uid = tarinfo.gid = 0 + tarinfo.uname = tarinfo.gname = "root" + return tarinfo + + tar_f.add(str(source_dir), arcname=".", filter=reset) - This function creates a .tar archive containing the build artifacts in a standardized - layout. It's intended to allow downstream automation to build TVM artifacts against the C - runtime. + +_GENERATED_VERSION = 2 + + +def _export_graph_model_library_format(mod: executor_factory.GraphExecutorFactoryModule, tempdir: pathlib.Path): + """Export a tvm.relay.build artifact in Model Library Format. Parameters ---------- mod : tvm.relay.backend.executor_factory.ExecutorFactoryModule The return value of tvm.relay.build, which will be exported into Model Library Format. - file_name : str - Path to the .tar archive to generate. - - Returns - ------- - file_name : str - The path to the generated .tar archive. + tempdir : pathlib.Path + Temporary directory to populate with Model Library Format contents. """ - tempdir = utils.tempdir() is_aot = isinstance(mod, executor_factory.AOTExecutorFactoryModule) runtime = ["aot"] if is_aot else ["graph"] metadata = { - "version": 2, + "version": _GENERATED_VERSION, "model_name": mod.libmod_name, "export_datetime": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%SZ"), "memory": _build_memory_map(mod), "target": {int(k): str(v) for k, v in mod.target.items()}, "runtimes": runtime, + "style": "full-model", } - with open(tempdir.relpath("metadata.json"), "w") as json_f: + with open(tempdir / "metadata.json", "w") as json_f: json.dump(metadata, json_f, indent=2, sort_keys=True) - codegen_dir_path = tempdir.relpath("codegen") - os.mkdir(codegen_dir_path) - _populate_codegen_dir(mod.lib, codegen_dir_path) + codegen_dir = tempdir / "codegen" + codegen_dir.mkdir() + _populate_codegen_dir(mod.lib, codegen_dir) - parameters_dir_path = tempdir.relpath("parameters") - os.mkdir(parameters_dir_path) - param_filename = os.path.join(parameters_dir_path, f"{mod.libmod_name}.params") + parameters_dir = tempdir / "parameters" + parameters_dir.mkdir() + param_filename = parameters_dir / f"{mod.libmod_name}.params" with open(param_filename, "wb") as f: f.write(param_dict.save_param_dict(mod.params)) - with open(tempdir.relpath("relay.txt"), "w") as f: + src_dir = tempdir / "src" + src_dir.mkdir() + with open(src_dir / "relay.txt", "w") as f: f.write(str(mod.ir_mod)) if not is_aot: - graph_config_dir_path = tempdir.relpath(os.path.join("runtime-config", "graph")) - os.makedirs(graph_config_dir_path) - with open(os.path.join(graph_config_dir_path, "graph.json"), "w") as f: + graph_config_dir = tempdir / "runtime-config" / "graph" + graph_config_dir.mkdir(parents=True) + with open(graph_config_dir / "graph.json", "w") as f: f.write(mod.get_executor_config()) - with tarfile.open(file_name, "w") as tar_f: - def reset(tarinfo): - tarinfo.uid = tarinfo.gid = 0 - tarinfo.uname = tarinfo.gname = "root" - return tarinfo +class NonStaticShapeError(Exception): + """Raised when a shape has elements other than IntImm.""" + + +def _shape_to_size(shape, dtype): + bits_per_item = int(re.match(r"((float)|(int))(?P[0-9]+)", dtype).group('width_bits')) + assert bits_per_item is not None, f"don't know how to compute size of type {dtype}" + total_bits = bits_per_item + for s in shape: + total_bits *= s + + return (total_bits + 7) // 8 + + +def _write_tir_and_build_operator_memory_map(src_dir, targets, ir_module_by_target): + def _eval_shape(param_name, buffer_shape): + shape = [] + for x in buffer_shape: + if not isinstance(x, expr.IntImm): + raise NonStaticShapeError(f"Parameter {param_name} has shape with non-IntImm elements: {buffer_shape}") + shape.append(x.value) + return shape + + memory_map = [] + storage_id = 0 + for target_device_type, target in targets.items(): + ir_mod = ir_module_by_target[target] + printer = get_global_func("tir.ModelLibraryFormatPrinter")(False, None, False) + with open(src_dir / f"tir-{target_device_type}.txt", "w") as f: + f.write(printer["print"](ir_mod)) + + for v in ir_mod.get_global_vars(): + for p, b in ir_mod[v.name_hint].buffer_map.items(): + shape = _eval_shape(p.name, b.shape) + buffer_size_bytes = _shape_to_size(shape, str(b.dtype)) + # NOTE: cannot tell what is an input or output at this point. + map_entry = { + "storage_id": storage_id, + "size_bytes": buffer_size_bytes, + "input_binding": printer["get_var_name"](p), + } + storage_id += 1 + memory_map.append(map_entry) + + return memory_map + + +def _export_operator_model_library_format(mod: build_module.OperatorModule, tempdir): + """Export the result of tvm.build() in Model Library Format. + + Parameters + ---------- + mod : runtime.Module + The Module returned from tvm.build(). + args : list of Buffer or Tensor or Var, optional + The args supplied to tvm.build(). + file_name : str + Path to the .tar archive to generate. + """ + targets = {} + for target in mod.ir_module_by_target.keys(): + if str(target.kind) not in ("llvm", "c"): + raise UnsupportedInModelLibraryFormatError( + f"Operator has non-DSO-exportable target {target!s}, which is not yet supported in " + "Model Library Format") + + targets[int(_nd.device(str(target)).device_type)] = target + + src_dir = tempdir / "src" + src_dir.mkdir() + memory_map = _write_tir_and_build_operator_memory_map(src_dir, targets, mod.ir_module_by_target) + + metadata = { + "version": _GENERATED_VERSION, + "model_name": mod.name, + "export_datetime": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%SZ"), + "memory": memory_map, + "target": {k: str(v) for k, v in targets.items()}, + "runtimes": [], + "style": "operator", + } + with open(tempdir / "metadata.json", "w") as metadata_f: + json.dump(metadata, metadata_f) + + codegen_dir = tempdir / "codegen" + codegen_dir.mkdir() + _populate_codegen_dir(mod, codegen_dir) + + +ExportableModule = typing.Union[build_module.OperatorModule, + executor_factory.GraphExecutorFactoryModule] + +def export_model_library_format(mod: ExportableModule, file_name: typing.Union[str, pathlib.Path]): + """Export the build artifact in Model Library Format. + + This function creates a .tar archive containing the build artifacts in a standardized + layout. It's intended to allow downstream automation to build TVM artifacts against the C + runtime. + + Parameters + ---------- + mod : ExportableModule + The return value of tvm.build or tvm.relay.build. + file_name : str + Path to the .tar archive to generate. + + Returns + ------- + file_name : str + The path to the generated .tar archive. + """ + file_name = pathlib.Path(file_name) + + tempdir = pathlib.Path(utils.tempdir().temp_dir) + tempdir.mkdir() + + if isinstance(mod, build_module.OperatorModule): + _export_operator_model_library_format(mod, tempdir) + elif isinstance(mod, executor_factory.GraphExecutorFactoryModule): + _export_graph_model_library_format(mod, tempdir) + else: + raise NotImplementedError(f"Don't know how to export module of type {mod.__class__!r}") - tar_f.add(tempdir.temp_dir, arcname=".", filter=reset) + _make_tar(tempdir, file_name) return file_name diff --git a/tests/python/unittest/test_micro_model_library_format.py b/tests/python/unittest/test_micro_model_library_format.py index d2c519da22b5..a877added3cc 100644 --- a/tests/python/unittest/test_micro_model_library_format.py +++ b/tests/python/unittest/test_micro_model_library_format.py @@ -32,6 +32,53 @@ from tvm.contrib import utils +@tvm.testing.requires_micro +def test_export_operator_model_library_format(): + import tvm.micro as micro + + target = tvm.target.target.micro("host") + with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): + A = tvm.te.placeholder((2,), dtype="int8") + B = tvm.te.placeholder((1,), dtype="int8") + C = tvm.te.compute(A.shape, lambda i: A[i] + B[0], name="C") + sched = tvm.te.create_schedule(C.op) + mod = tvm.build(sched, [A, B, C], tvm.target.Target(target, target), name="add") + + temp_dir = utils.tempdir() + mlf_tar_path = temp_dir.relpath("lib.tar") + micro.export_model_library_format(mod, mlf_tar_path) + + tf = tarfile.open(mlf_tar_path) + + extract_dir = temp_dir.relpath("extract") + os.mkdir(extract_dir) + tf.extractall(extract_dir) + + with open(os.path.join(extract_dir, "metadata.json")) as json_f: + metadata = json.load(json_f) + assert metadata["version"] == 2 + assert metadata["model_name"] == "add" + export_datetime = datetime.datetime.strptime( + metadata["export_datetime"], "%Y-%m-%d %H:%M:%SZ" + ) + assert (datetime.datetime.now() - export_datetime) < datetime.timedelta(seconds=60 * 5) + assert metadata["target"] == {"1": str(target)} + assert metadata["memory"] == [ + {"storage_id": 0, "size_bytes": 2, "input_binding": "placeholder_2"}, + {"storage_id": 1, "size_bytes": 1, "input_binding": "placeholder_3"}, + {"storage_id": 2, "size_bytes": 2, "input_binding": "C_1"}, + ] + + assert os.path.exists(os.path.join(extract_dir, "codegen", "host", "src", "lib0.c")) + assert os.path.exists(os.path.join(extract_dir, "codegen", "host", "src", "lib1.c")) + + assert len(mod.ir_module_by_target) == 1, f"expect 1 ir_modele_by_target: {ir_module_by_target!r}" + for target, ir_mod in mod.ir_module_by_target.items(): + assert int(tvm.runtime.ndarray.device(str(target)).device_type) == 1 + with open(os.path.join(extract_dir, "src", "tir-1.txt")) as tir_f: + assert tir_f.read() == str(ir_mod) + + def validate_graph_json(extract_dir, factory): with open(os.path.join(extract_dir, "runtime-config", "graph", "graph.json")) as graph_f: graph_json = graph_f.read() From d0aa1806e5cfe2377a3ff54dfd0db873362a778f Mon Sep 17 00:00:00 2001 From: Andrew Reusch Date: Tue, 18 May 2021 12:34:05 -0700 Subject: [PATCH 05/16] Add pathlib convenience functions to utils.TempDirectory. --- python/tvm/contrib/utils.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/python/tvm/contrib/utils.py b/python/tvm/contrib/utils.py index 6451896c6bd1..28f4e7e7e5b5 100644 --- a/python/tvm/contrib/utils.py +++ b/python/tvm/contrib/utils.py @@ -19,6 +19,7 @@ import contextlib import datetime import os +import pathlib import tempfile import threading import shutil @@ -119,6 +120,16 @@ def remove(self): self.TEMPDIRS.remove(self.temp_dir) self.temp_dir = None + @property + def path(self): + return pathlib.Path(self.temp_dir) + + def __div__(self, other): + if not isinstance(other, (str, pathlib.Path)): + raise TypeError("TempDirectory / operator: must supply str or pathlib.Path; got %r" % (other,)) + + return self.path / other + def __del__(self): temp_dirs = getattr(self, "TEMPDIRS", None) if temp_dirs is None: From 408d154067566ee27b89f6ec0a9f2ecab67ad37b Mon Sep 17 00:00:00 2001 From: Andrew Reusch Date: Tue, 18 May 2021 12:34:41 -0700 Subject: [PATCH 06/16] fix tests --- python/tvm/micro/model_library_format.py | 13 +++++++------ .../unittest/test_micro_model_library_format.py | 6 +++--- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/python/tvm/micro/model_library_format.py b/python/tvm/micro/model_library_format.py index 30ab014eb7f6..3b8bca0ac52c 100644 --- a/python/tvm/micro/model_library_format.py +++ b/python/tvm/micro/model_library_format.py @@ -362,6 +362,7 @@ def _export_operator_model_library_format(mod: build_module.OperatorModule, temp ExportableModule = typing.Union[build_module.OperatorModule, + executor_factory.AOTExecutorFactoryModule, executor_factory.GraphExecutorFactoryModule] def export_model_library_format(mod: ExportableModule, file_name: typing.Union[str, pathlib.Path]): @@ -385,16 +386,16 @@ def export_model_library_format(mod: ExportableModule, file_name: typing.Union[s """ file_name = pathlib.Path(file_name) - tempdir = pathlib.Path(utils.tempdir().temp_dir) - tempdir.mkdir() + tempdir = utils.tempdir() if isinstance(mod, build_module.OperatorModule): - _export_operator_model_library_format(mod, tempdir) - elif isinstance(mod, executor_factory.GraphExecutorFactoryModule): - _export_graph_model_library_format(mod, tempdir) + _export_operator_model_library_format(mod, tempdir.path) + elif isinstance(mod, (executor_factory.AOTExecutorFactoryModule, + executor_factory.GraphExecutorFactoryModule)): + _export_graph_model_library_format(mod, tempdir.path) else: raise NotImplementedError(f"Don't know how to export module of type {mod.__class__!r}") - _make_tar(tempdir, file_name) + _make_tar(tempdir.path, file_name) return file_name diff --git a/tests/python/unittest/test_micro_model_library_format.py b/tests/python/unittest/test_micro_model_library_format.py index a877added3cc..048fc3239135 100644 --- a/tests/python/unittest/test_micro_model_library_format.py +++ b/tests/python/unittest/test_micro_model_library_format.py @@ -168,7 +168,7 @@ def @main(%a : Tensor[(1, 2), uint8], %b : Tensor[(1, 2), float32], %c : Tensor[ if executor == "graph": validate_graph_json(extract_dir, factory) - with open(os.path.join(extract_dir, "relay.txt")) as relay_f: + with open(os.path.join(extract_dir, "src", "relay.txt")) as relay_f: assert relay_f.read() == str(relay_mod) with open(os.path.join(extract_dir, "parameters", "add.params"), "rb") as params_f: @@ -245,7 +245,7 @@ def @main(%a : Tensor[(1, 2), uint8], %b : Tensor[(1, 2), float32], %c : Tensor[ validate_graph_json(extract_dir, factory) - with open(os.path.join(extract_dir, "relay.txt")) as relay_f: + with open(os.path.join(extract_dir, "src", "relay.txt")) as relay_f: assert relay_f.read() == str(relay_mod) with open(os.path.join(extract_dir, "parameters", "add.params"), "rb") as params_f: @@ -316,7 +316,7 @@ def @main(%p0: Tensor[(1, 56, 56, 128), int16], %p1: Tensor[(3, 3, 128, 1), int1 @tvm.testing.requires_micro -def test_export_model(): +def test_export_non_dso_exportable(): module = tvm.support.FrontendTestModule() factory = executor_factory.GraphExecutorFactoryModule( None, tvm.target.target.micro("host"), '"graph_json"', module, "test_module", {}, {} From 2537d3ab5ce2b7ea349218e1a25d63905c2ee207 Mon Sep 17 00:00:00 2001 From: Andrew Reusch Date: Tue, 18 May 2021 12:36:45 -0700 Subject: [PATCH 07/16] black format --- python/tvm/contrib/utils.py | 4 ++- python/tvm/driver/build_module.py | 1 - python/tvm/micro/model_library_format.py | 31 +++++++++++++------ .../test_micro_model_library_format.py | 4 ++- 4 files changed, 28 insertions(+), 12 deletions(-) diff --git a/python/tvm/contrib/utils.py b/python/tvm/contrib/utils.py index 28f4e7e7e5b5..68c6b3d5bf6b 100644 --- a/python/tvm/contrib/utils.py +++ b/python/tvm/contrib/utils.py @@ -126,7 +126,9 @@ def path(self): def __div__(self, other): if not isinstance(other, (str, pathlib.Path)): - raise TypeError("TempDirectory / operator: must supply str or pathlib.Path; got %r" % (other,)) + raise TypeError( + "TempDirectory / operator: must supply str or pathlib.Path; got %r" % (other,) + ) return self.path / other diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index 8682cc27e34b..7f01a5d3cf97 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -387,7 +387,6 @@ def build( class OperatorModule(Module): - @classmethod def from_module(cls, mod, **kwargs): # NOTE(areusch): It is generally unsafe to continue using `mod` from this point forward. diff --git a/python/tvm/micro/model_library_format.py b/python/tvm/micro/model_library_format.py index 3b8bca0ac52c..c155134d2599 100644 --- a/python/tvm/micro/model_library_format.py +++ b/python/tvm/micro/model_library_format.py @@ -213,6 +213,7 @@ def _build_function_memory_map(function_metadata): def _make_tar(source_dir, tar_file_path): """Build a tar file from source_dir.""" with tarfile.open(tar_file_path, "w") as tar_f: + def reset(tarinfo): tarinfo.uid = tarinfo.gid = 0 tarinfo.uname = tarinfo.gname = "root" @@ -224,7 +225,9 @@ def reset(tarinfo): _GENERATED_VERSION = 2 -def _export_graph_model_library_format(mod: executor_factory.GraphExecutorFactoryModule, tempdir: pathlib.Path): +def _export_graph_model_library_format( + mod: executor_factory.GraphExecutorFactoryModule, tempdir: pathlib.Path +): """Export a tvm.relay.build artifact in Model Library Format. Parameters @@ -277,7 +280,9 @@ class NonStaticShapeError(Exception): def _shape_to_size(shape, dtype): - bits_per_item = int(re.match(r"((float)|(int))(?P[0-9]+)", dtype).group('width_bits')) + bits_per_item = int( + re.match(r"((float)|(int))(?P[0-9]+)", dtype).group("width_bits") + ) assert bits_per_item is not None, f"don't know how to compute size of type {dtype}" total_bits = bits_per_item for s in shape: @@ -291,7 +296,9 @@ def _eval_shape(param_name, buffer_shape): shape = [] for x in buffer_shape: if not isinstance(x, expr.IntImm): - raise NonStaticShapeError(f"Parameter {param_name} has shape with non-IntImm elements: {buffer_shape}") + raise NonStaticShapeError( + f"Parameter {param_name} has shape with non-IntImm elements: {buffer_shape}" + ) shape.append(x.value) return shape @@ -336,7 +343,8 @@ def _export_operator_model_library_format(mod: build_module.OperatorModule, temp if str(target.kind) not in ("llvm", "c"): raise UnsupportedInModelLibraryFormatError( f"Operator has non-DSO-exportable target {target!s}, which is not yet supported in " - "Model Library Format") + "Model Library Format" + ) targets[int(_nd.device(str(target)).device_type)] = target @@ -361,9 +369,12 @@ def _export_operator_model_library_format(mod: build_module.OperatorModule, temp _populate_codegen_dir(mod, codegen_dir) -ExportableModule = typing.Union[build_module.OperatorModule, - executor_factory.AOTExecutorFactoryModule, - executor_factory.GraphExecutorFactoryModule] +ExportableModule = typing.Union[ + build_module.OperatorModule, + executor_factory.AOTExecutorFactoryModule, + executor_factory.GraphExecutorFactoryModule, +] + def export_model_library_format(mod: ExportableModule, file_name: typing.Union[str, pathlib.Path]): """Export the build artifact in Model Library Format. @@ -390,8 +401,10 @@ def export_model_library_format(mod: ExportableModule, file_name: typing.Union[s if isinstance(mod, build_module.OperatorModule): _export_operator_model_library_format(mod, tempdir.path) - elif isinstance(mod, (executor_factory.AOTExecutorFactoryModule, - executor_factory.GraphExecutorFactoryModule)): + elif isinstance( + mod, + (executor_factory.AOTExecutorFactoryModule, executor_factory.GraphExecutorFactoryModule), + ): _export_graph_model_library_format(mod, tempdir.path) else: raise NotImplementedError(f"Don't know how to export module of type {mod.__class__!r}") diff --git a/tests/python/unittest/test_micro_model_library_format.py b/tests/python/unittest/test_micro_model_library_format.py index 048fc3239135..bf79584e0dcc 100644 --- a/tests/python/unittest/test_micro_model_library_format.py +++ b/tests/python/unittest/test_micro_model_library_format.py @@ -72,7 +72,9 @@ def test_export_operator_model_library_format(): assert os.path.exists(os.path.join(extract_dir, "codegen", "host", "src", "lib0.c")) assert os.path.exists(os.path.join(extract_dir, "codegen", "host", "src", "lib1.c")) - assert len(mod.ir_module_by_target) == 1, f"expect 1 ir_modele_by_target: {ir_module_by_target!r}" + assert ( + len(mod.ir_module_by_target) == 1 + ), f"expect 1 ir_modele_by_target: {ir_module_by_target!r}" for target, ir_mod in mod.ir_module_by_target.items(): assert int(tvm.runtime.ndarray.device(str(target)).device_type) == 1 with open(os.path.join(extract_dir, "src", "tir-1.txt")) as tir_f: From c04e8f3862c8576116b93580902956f2f7e4a2e8 Mon Sep 17 00:00:00 2001 From: Andrew Reusch Date: Tue, 18 May 2021 12:37:10 -0700 Subject: [PATCH 08/16] git-clang-format --- src/printer/model_library_format_printer.cc | 37 ++++++++++++--------- src/printer/text_printer.h | 4 +-- 2 files changed, 22 insertions(+), 19 deletions(-) diff --git a/src/printer/model_library_format_printer.cc b/src/printer/model_library_format_printer.cc index 4834ec6aaefa..9f4dea5b5dcf 100644 --- a/src/printer/model_library_format_printer.cc +++ b/src/printer/model_library_format_printer.cc @@ -1,6 +1,7 @@ #include #include #include + #include "text_printer.h" namespace tvm { @@ -8,12 +9,12 @@ namespace printer { class ModelLibraryFormatPrinter : public ::tvm::runtime::ModuleNode { public: - ModelLibraryFormatPrinter(bool show_meta_data, const runtime::TypedPackedFunc& annotate, bool show_warning) : - text_printer_{show_meta_data, annotate, show_warning} {} + ModelLibraryFormatPrinter(bool show_meta_data, + const runtime::TypedPackedFunc& annotate, + bool show_warning) + : text_printer_{show_meta_data, annotate, show_warning} {} - const char* type_key() const override { - return "model_library_format_printer"; - } + const char* type_key() const override { return "model_library_format_printer"; } std::string Print(const ObjectRef& node) { Doc doc; @@ -23,16 +24,17 @@ class ModelLibraryFormatPrinter : public ::tvm::runtime::ModuleNode { PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) override { if (name == "print") { - return TypedPackedFunc([sptr_to_self, this](ObjectRef node) { return Print(node); }); + return TypedPackedFunc( + [sptr_to_self, this](ObjectRef node) { return Print(node); }); } else if (name == "get_var_name") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - ICHECK_EQ(args.size(), 1) << "usage: get_var_name(Var v)"; + ICHECK_EQ(args.size(), 1) << "usage: get_var_name(Var v)"; - std::string var_name; - if (text_printer_.GetVarName(args[0], &var_name)) { - *rv = var_name; - } - }); + std::string var_name; + if (text_printer_.GetVarName(args[0], &var_name)) { + *rv = var_name; + } + }); } else { return PackedFunc(); } @@ -42,10 +44,13 @@ class ModelLibraryFormatPrinter : public ::tvm::runtime::ModuleNode { TextPrinter text_printer_; }; -TVM_REGISTER_GLOBAL("tir.ModelLibraryFormatPrinter").set_body_typed( - [](bool show_meta_data, const runtime::TypedPackedFunc& annotate, bool show_warning) { - return ObjectRef(make_object(show_meta_data, annotate, show_warning)); - }); +TVM_REGISTER_GLOBAL("tir.ModelLibraryFormatPrinter") + .set_body_typed([](bool show_meta_data, + const runtime::TypedPackedFunc& annotate, + bool show_warning) { + return ObjectRef( + make_object(show_meta_data, annotate, show_warning)); + }); } // namespace printer } // namespace tvm diff --git a/src/printer/text_printer.h b/src/printer/text_printer.h index 6400c0904a5b..0332a2d539d2 100644 --- a/src/printer/text_printer.h +++ b/src/printer/text_printer.h @@ -402,9 +402,7 @@ class TextPrinter { /*! \brief TIR Text Printer */ tir::TIRTextPrinter tir_text_printer_; - bool GetVarName(::tvm::tir::Var v, std::string* s) { - return tir_text_printer_.GetVarName(v, s); - } + bool GetVarName(::tvm::tir::Var v, std::string* s) { return tir_text_printer_.GetVarName(v, s); } Doc PrintFinal(const ObjectRef& node) { Doc doc; From 9fff1023267877bf1112c8cf9571f525a079278f Mon Sep 17 00:00:00 2001 From: Andrew Reusch Date: Tue, 18 May 2021 12:45:00 -0700 Subject: [PATCH 09/16] pylint fixes --- python/tvm/driver/build_module.py | 2 ++ python/tvm/micro/model_library_format.py | 1 - 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index 7f01a5d3cf97..d46bd5c833a7 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -387,6 +387,8 @@ def build( class OperatorModule(Module): + """Wraps the Module returned by tvm.build() and captures additional outputs of that function.""" + @classmethod def from_module(cls, mod, **kwargs): # NOTE(areusch): It is generally unsafe to continue using `mod` from this point forward. diff --git a/python/tvm/micro/model_library_format.py b/python/tvm/micro/model_library_format.py index c155134d2599..ceb285786cc7 100644 --- a/python/tvm/micro/model_library_format.py +++ b/python/tvm/micro/model_library_format.py @@ -29,7 +29,6 @@ from ..contrib import utils from ..driver import build_module from ..runtime import ndarray as _nd -from ..relay import build_module as relay_build_module from ..relay.backend import executor_factory from ..relay import param_dict from ..tir import expr From b007dfadab5c64a64f83b3b9132c94021cdf4180 Mon Sep 17 00:00:00 2001 From: Andrew Reusch Date: Tue, 18 May 2021 12:45:06 -0700 Subject: [PATCH 10/16] add asf header --- src/printer/model_library_format_printer.cc | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/src/printer/model_library_format_printer.cc b/src/printer/model_library_format_printer.cc index 9f4dea5b5dcf..c3eb99717598 100644 --- a/src/printer/model_library_format_printer.cc +++ b/src/printer/model_library_format_printer.cc @@ -1,3 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + #include #include #include From 5c20bd61a4ea5245882d0623c08917847fd58309 Mon Sep 17 00:00:00 2001 From: Andrew Reusch Date: Wed, 19 May 2021 17:14:31 -0700 Subject: [PATCH 11/16] change memory map to make more sense, fix tests --- python/tvm/micro/model_library_format.py | 20 ++++++++++--------- .../test_micro_model_library_format.py | 19 ++++++++++++------ 2 files changed, 24 insertions(+), 15 deletions(-) diff --git a/python/tvm/micro/model_library_format.py b/python/tvm/micro/model_library_format.py index ceb285786cc7..711699e39e70 100644 --- a/python/tvm/micro/model_library_format.py +++ b/python/tvm/micro/model_library_format.py @@ -301,8 +301,7 @@ def _eval_shape(param_name, buffer_shape): shape.append(x.value) return shape - memory_map = [] - storage_id = 0 + memory_map = {} for target_device_type, target in targets.items(): ir_mod = ir_module_by_target[target] printer = get_global_func("tir.ModelLibraryFormatPrinter")(False, None, False) @@ -310,17 +309,20 @@ def _eval_shape(param_name, buffer_shape): f.write(printer["print"](ir_mod)) for v in ir_mod.get_global_vars(): + map_entry = [] for p, b in ir_mod[v.name_hint].buffer_map.items(): shape = _eval_shape(p.name, b.shape) buffer_size_bytes = _shape_to_size(shape, str(b.dtype)) # NOTE: cannot tell what is an input or output at this point. - map_entry = { - "storage_id": storage_id, - "size_bytes": buffer_size_bytes, - "input_binding": printer["get_var_name"](p), - } - storage_id += 1 - memory_map.append(map_entry) + map_entry.append( + { + "size_bytes": buffer_size_bytes, + "shape": [int(x) for x in b.shape], + "dtype": b.dtype, + "input_binding": printer["get_var_name"](p), + } + ) + memory_map[v.name_hint] = map_entry return memory_map diff --git a/tests/python/unittest/test_micro_model_library_format.py b/tests/python/unittest/test_micro_model_library_format.py index bf79584e0dcc..da0857e2ee29 100644 --- a/tests/python/unittest/test_micro_model_library_format.py +++ b/tests/python/unittest/test_micro_model_library_format.py @@ -63,18 +63,25 @@ def test_export_operator_model_library_format(): ) assert (datetime.datetime.now() - export_datetime) < datetime.timedelta(seconds=60 * 5) assert metadata["target"] == {"1": str(target)} - assert metadata["memory"] == [ - {"storage_id": 0, "size_bytes": 2, "input_binding": "placeholder_2"}, - {"storage_id": 1, "size_bytes": 1, "input_binding": "placeholder_3"}, - {"storage_id": 2, "size_bytes": 2, "input_binding": "C_1"}, - ] + + assert metadata["memory"]["add"][0]["dtype"] == "int8" + assert metadata["memory"]["add"][0]["shape"] == [2] + assert metadata["memory"]["add"][0]["size_bytes"] == 2 + + assert metadata["memory"]["add"][1]["dtype"] == "int8" + assert metadata["memory"]["add"][1]["shape"] == [1] + assert metadata["memory"]["add"][1]["size_bytes"] == 1 + + assert metadata["memory"]["add"][2]["dtype"] == "int8" + assert metadata["memory"]["add"][2]["shape"] == [2] + assert metadata["memory"]["add"][2]["size_bytes"] == 2 assert os.path.exists(os.path.join(extract_dir, "codegen", "host", "src", "lib0.c")) assert os.path.exists(os.path.join(extract_dir, "codegen", "host", "src", "lib1.c")) assert ( len(mod.ir_module_by_target) == 1 - ), f"expect 1 ir_modele_by_target: {ir_module_by_target!r}" + ), f"expect 1 ir_model_by_target: {ir_module_by_target!r}" for target, ir_mod in mod.ir_module_by_target.items(): assert int(tvm.runtime.ndarray.device(str(target)).device_type) == 1 with open(os.path.join(extract_dir, "src", "tir-1.txt")) as tir_f: From 7b1ef1aa9413f14e7f82e811305f43218bc93aaa Mon Sep 17 00:00:00 2001 From: Andrew Reusch Date: Mon, 21 Jun 2021 08:03:25 -0700 Subject: [PATCH 12/16] address giuseros comments --- python/tvm/micro/model_library_format.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/micro/model_library_format.py b/python/tvm/micro/model_library_format.py index 711699e39e70..ea44cfdca06d 100644 --- a/python/tvm/micro/model_library_format.py +++ b/python/tvm/micro/model_library_format.py @@ -225,7 +225,7 @@ def reset(tarinfo): def _export_graph_model_library_format( - mod: executor_factory.GraphExecutorFactoryModule, tempdir: pathlib.Path + mod: executor_factory.ExecutorFactoryModule, tempdir: pathlib.Path ): """Export a tvm.relay.build artifact in Model Library Format. From 32abcf2096debfe8fefc05c2c6f34682e3282b92 Mon Sep 17 00:00:00 2001 From: Andrew Reusch Date: Mon, 28 Jun 2021 17:12:24 -0700 Subject: [PATCH 13/16] align GetVarName with future TypedPackedFunc --- src/printer/model_library_format_printer.cc | 22 +++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/src/printer/model_library_format_printer.cc b/src/printer/model_library_format_printer.cc index c3eb99717598..0f5d080f26b9 100644 --- a/src/printer/model_library_format_printer.cc +++ b/src/printer/model_library_format_printer.cc @@ -21,6 +21,8 @@ #include #include +#include + #include "text_printer.h" namespace tvm { @@ -41,19 +43,23 @@ class ModelLibraryFormatPrinter : public ::tvm::runtime::ModuleNode { return doc.str(); } + TVMRetValue GetVarName(tir::Var var) { + TVMRetValue rv; + std::string var_name; + if (text_printer_.GetVarName(var, &var_name)) { + rv = var_name; + } + + return rv; + } + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) override { if (name == "print") { return TypedPackedFunc( [sptr_to_self, this](ObjectRef node) { return Print(node); }); } else if (name == "get_var_name") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - ICHECK_EQ(args.size(), 1) << "usage: get_var_name(Var v)"; - - std::string var_name; - if (text_printer_.GetVarName(args[0], &var_name)) { - *rv = var_name; - } - }); + return TypedPackedFunc( + [sptr_to_self, this](tir::Var var) { return GetVarName(var); }); } else { return PackedFunc(); } From be680af3e3bb92042012f00a0068963d3ea810aa Mon Sep 17 00:00:00 2001 From: Andrew Reusch Date: Mon, 28 Jun 2021 17:26:29 -0700 Subject: [PATCH 14/16] fix test --- tests/python/unittest/test_micro_model_library_format.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/unittest/test_micro_model_library_format.py b/tests/python/unittest/test_micro_model_library_format.py index f3e20ad48ea1..e1a41ec9e67b 100644 --- a/tests/python/unittest/test_micro_model_library_format.py +++ b/tests/python/unittest/test_micro_model_library_format.py @@ -56,7 +56,7 @@ def test_export_operator_model_library_format(): with open(os.path.join(extract_dir, "metadata.json")) as json_f: metadata = json.load(json_f) - assert metadata["version"] == 2 + assert metadata["version"] == 3 assert metadata["model_name"] == "add" export_datetime = datetime.datetime.strptime( metadata["export_datetime"], "%Y-%m-%d %H:%M:%SZ" From ed3008d33a7684088752ca23615e3c8036f3ed9b Mon Sep 17 00:00:00 2001 From: Andrew Reusch Date: Tue, 29 Jun 2021 13:29:48 -0700 Subject: [PATCH 15/16] clang-format --- src/printer/model_library_format_printer.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/printer/model_library_format_printer.cc b/src/printer/model_library_format_printer.cc index 0f5d080f26b9..17ba84e68df4 100644 --- a/src/printer/model_library_format_printer.cc +++ b/src/printer/model_library_format_printer.cc @@ -59,7 +59,7 @@ class ModelLibraryFormatPrinter : public ::tvm::runtime::ModuleNode { [sptr_to_self, this](ObjectRef node) { return Print(node); }); } else if (name == "get_var_name") { return TypedPackedFunc( - [sptr_to_self, this](tir::Var var) { return GetVarName(var); }); + [sptr_to_self, this](tir::Var var) { return GetVarName(var); }); } else { return PackedFunc(); } From c200ef5f08420a4997dccec8ff572ed5b454937d Mon Sep 17 00:00:00 2001 From: Andrew Reusch Date: Wed, 30 Jun 2021 08:58:25 -0700 Subject: [PATCH 16/16] rev model library format to v4 (bad merge) --- python/tvm/micro/model_library_format.py | 2 +- tests/python/unittest/test_micro_model_library_format.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/python/tvm/micro/model_library_format.py b/python/tvm/micro/model_library_format.py index edda20baf415..87c067051f82 100644 --- a/python/tvm/micro/model_library_format.py +++ b/python/tvm/micro/model_library_format.py @@ -225,7 +225,7 @@ def reset(tarinfo): tar_f.add(str(source_dir), arcname=".", filter=reset) -_GENERATED_VERSION = 3 +_GENERATED_VERSION = 4 def _export_graph_model_library_format( diff --git a/tests/python/unittest/test_micro_model_library_format.py b/tests/python/unittest/test_micro_model_library_format.py index e1a41ec9e67b..246c0336a001 100644 --- a/tests/python/unittest/test_micro_model_library_format.py +++ b/tests/python/unittest/test_micro_model_library_format.py @@ -56,7 +56,7 @@ def test_export_operator_model_library_format(): with open(os.path.join(extract_dir, "metadata.json")) as json_f: metadata = json.load(json_f) - assert metadata["version"] == 3 + assert metadata["version"] == 4 assert metadata["model_name"] == "add" export_datetime = datetime.datetime.strptime( metadata["export_datetime"], "%Y-%m-%d %H:%M:%SZ" @@ -141,7 +141,7 @@ def @main(%a : Tensor[(1, 2), uint8], %b : Tensor[(1, 2), float32], %c : Tensor[ with open(os.path.join(extract_dir, "metadata.json")) as json_f: metadata = json.load(json_f) - assert metadata["version"] == 3 + assert metadata["version"] == 4 assert metadata["model_name"] == "add" export_datetime = datetime.datetime.strptime( metadata["export_datetime"], "%Y-%m-%d %H:%M:%SZ" @@ -221,7 +221,7 @@ def @main(%a : Tensor[(1, 2), uint8], %b : Tensor[(1, 2), float32], %c : Tensor[ with open(os.path.join(extract_dir, "metadata.json")) as json_f: metadata = json.load(json_f) - assert metadata["version"] == 3 + assert metadata["version"] == 4 assert metadata["model_name"] == "add" export_datetime = datetime.datetime.strptime( metadata["export_datetime"], "%Y-%m-%d %H:%M:%SZ" @@ -300,7 +300,7 @@ def @main(%p0: Tensor[(1, 56, 56, 128), int16], %p1: Tensor[(3, 3, 128, 1), int1 with open(os.path.join(extract_dir, "metadata.json")) as json_f: metadata = json.load(json_f) - assert metadata["version"] == 3 + assert metadata["version"] == 4 assert metadata["model_name"] == "qnn_conv2d" export_datetime = datetime.datetime.strptime( metadata["export_datetime"], "%Y-%m-%d %H:%M:%SZ"