From 58ed6dc9be9e96afe5e13578a3963a185fead488 Mon Sep 17 00:00:00 2001
From: "meng.tong"
Date: Wed, 26 Feb 2025 18:40:46 +0800
Subject: [PATCH 1/5] remove relay

---
 .../tvm/contrib/msc/core/codegen/codegen.py    |   43 -
 .../contrib/msc/core/frontend/translate.py     |  114 --
 .../tvm/contrib/msc/core/transform/pattern.py  |  257 ---
 .../contrib/msc/core/transform/transform.py    |   46 +-
 .../tensorflow/frontend/translate.py           |   23 +-
 .../msc/framework/torch/frontend/translate.py  |   43 +-
 src/contrib/msc/core/ir/graph_builder.cc       |  638 ++-----
 src/contrib/msc/core/ir/graph_builder.h        |  207 +-
 .../msc/core/transform/bind_named_params.cc    |    1 -
 .../msc/core/transform/set_expr_name.cc        |  223 ---
 src/contrib/msc/core/utils.cc                  |   19 +-
 src/contrib/msc/core/utils.h                   |   26 +-
 .../framework/tensorrt/transform_tensorrt.cc   |    1 +
 .../test_msc/test_translate_tensorflow.py      | 1691 -----------------
 .../contrib/test_msc/test_translate_torch.py   |  281 ++-
 15 files changed, 323 insertions(+), 3290 deletions(-)
 delete mode 100644 tests/python/contrib/test_msc/test_translate_tensorflow.py

diff --git a/python/tvm/contrib/msc/core/codegen/codegen.py b/python/tvm/contrib/msc/core/codegen/codegen.py
index 7e3ddd5e07d4..96c9c23dfd9d 100644
--- a/python/tvm/contrib/msc/core/codegen/codegen.py
+++ b/python/tvm/contrib/msc/core/codegen/codegen.py
@@ -25,7 +25,6 @@
 from tvm.relax import PyExprVisitor
 from tvm.contrib.msc.core import transform as msc_transform
 from tvm.contrib.msc.core.ir import MSCGraph, MSCTensor
-from tvm.contrib.msc.core.frontend import from_relay
 from tvm.contrib.msc.core import utils as msc_utils
@@ -216,45 +215,3 @@ def _post_proc(mod: tvm.IRModule, folder: msc_utils.MSCDirectory) -> tvm.IRModul
     if plugin:
         model_args = model_args + [plugin]
     return codegen.load(model_args, pre_load=_save_weights, post_load=_post_proc)
-
-
-def relay_to_relax(
-    relay_mod: tvm.IRModule,
-    params: Optional[Dict[str, tvm.nd.array]] = None,
-    trans_config: Optional[Dict[str, str]] = None,
-    build_config: Optional[Dict[str, str]] = None,
-    opt_config: Optional[Dict[str, str]] = None,
-    build_folder: msc_utils.MSCDirectory = None,
-) -> tvm.IRModule:
-    """Change relay IRModule to relax MSCGraph.
-
-    Parameters
-    ----------
-    relay_mod: IRModule
-        The IRModule of relay.
-    params: dict of
-        The parameters of the IRModule.
-    trans_config: dict
-        The config for transform IRModule.
-    build_config: dict
-        The config for build MSCGraph.
-    opt_config: dict
-        The config for optimize the relay before translate.
-    build_folder: MSCDirectory
-        The folder for saving scripts and datas.
-
-    Returns
-    -------
-    relax_mod: IRModule
-        The IRModule of relax.
- """ - - graph, weights = from_relay( - relay_mod, - params, - trans_config=trans_config, - build_config=build_config, - opt_config=opt_config, - ) - - return to_relax(graph, weights, codegen_config={"from_relay": True}, build_folder=build_folder) diff --git a/python/tvm/contrib/msc/core/frontend/translate.py b/python/tvm/contrib/msc/core/frontend/translate.py index 8e9bb0cf00d7..00621adfdc65 100644 --- a/python/tvm/contrib/msc/core/frontend/translate.py +++ b/python/tvm/contrib/msc/core/frontend/translate.py @@ -22,9 +22,6 @@ from tvm.relax.transform import BindParams from tvm.relax import PyExprVisitor from tvm.relax.backend.pattern_registry import get_patterns_with_prefix -from tvm.relay.expr_functor import ExprVisitor -from tvm.relay.build_module import bind_params_by_name -from tvm.relay import dataflow_pattern as relay_pattern from tvm.contrib.msc.core import transform as msc_transform from tvm.contrib.msc.core import _ffi_api from tvm.contrib.msc.core import utils as msc_utils @@ -170,117 +167,6 @@ def from_relax( return graph, normalize_weights(t_weights, graph) -def get_relay_patterns( - mod: tvm.IRModule, - entry_name: str = "main", -) -> List[Tuple[str, relay_pattern.DFPattern, callable]]: - """Filter relay patterns based on mod. - - Parameters - ---------- - mod: IRModule - The IRModule of relay. - entry_name: str - The entry name. - - Returns - ------- - patterns: list - The useful patterns for relay - """ - - class OpExtractor(ExprVisitor): - """Extract ops from expr.""" - - def extract(self, expr): - self._optypes = set() - super().visit(expr) - return self._optypes - - def visit_call(self, expr): - super().visit_call(expr) - if isinstance(expr.op, tvm.ir.Op): - self._optypes.add(expr.op.name) - - op_names = OpExtractor().extract(mod[entry_name]) - skip_tags, patterns = set(), list(tvm.relay.op.contrib.get_pattern_table("msc")) - if "nn.conv1d" not in op_names or "add" not in op_names: - skip_tags.add("msc.conv1d_bias") - if "nn.conv2d" not in op_names or "add" not in op_names: - skip_tags.add("msc.conv2d_bias") - if "nn.batch_matmul" not in op_names or "add" not in op_names: - skip_tags.add("msc.linear_bias") - if "nn.batch_matmul" not in op_names: - skip_tags |= set(p[0] for p in patterns if p[0].startswith("msc.linear")) - if "nn.dense" not in op_names: - skip_tags |= set(p[0] for p in patterns if p[0].startswith("msc.matmul")) - if "take" not in op_names: - skip_tags |= set(p[0] for p in patterns if p[0].startswith("msc.embedding")) - if "erf" not in op_names: - skip_tags |= set(p[0] for p in patterns if p[0].startswith("msc.gelu")) - valid_patterns = [p for p in patterns if p[0] not in skip_tags] - return valid_patterns - - -def from_relay( - mod: tvm.IRModule, - params: Optional[Dict[str, tvm.nd.array]] = None, - trans_config: Optional[Dict[str, str]] = None, - build_config: Optional[Dict[str, str]] = None, - opt_config: Optional[Dict[str, str]] = None, -) -> Tuple[MSCGraph, Dict[str, tvm.nd.array]]: - """Change IRModule to MSCGraph. - - Parameters - ---------- - mod: IRModule - The IRModule of relay. - params: dict of - The parameters of the IRModule. - trans_config: dict - The config for transform IRModule. - build_config: dict - The config for build MSCGraph. - opt_config: dict - The config for optimize the relay before translate. - - Returns - ------- - graph: tvm.contrib.msc.core.ir.MSCGraph - The translated graph. - weights: dict of - The weights from the IRModule. 
- """ - - trans_config = msc_utils.copy_dict(trans_config) - build_config = msc_utils.copy_dict(build_config) - opt_config = msc_utils.copy_dict(opt_config) - # TODO(tong.meng): optimize before translate? - opt_level = opt_config.get("opt_level", 0) - if params: - mod["main"] = bind_params_by_name(mod["main"], params) - if opt_level > 0: - target = opt_config.get("target", "llvm") - disabled_pass = opt_config.get("disabled_pass", []) + [ - "SimplifyInference", - "CanonicalizeOps", - "FuseOps", - "AlterOpLayout", - ] - with tvm.transform.PassContext(opt_level=opt_level, disabled_pass=disabled_pass): - mod, params = tvm.relay.optimize(mod, target=target, params=params) - patterns = get_relay_patterns(mod) - passes = [ - tvm.relay.transform.InferType(), - tvm.relay.transform.MergeComposite(patterns), - msc_transform.SetExprName(as_relax=False), - ] - mod = tvm.transform.Sequential(passes)(mod) - graph = _ffi_api.BuildFromRelay(mod, "main", msc_utils.dump_dict(build_config)) - t_weights = _ffi_api.GetRelayWeights(mod, "main") - return graph, normalize_weights(t_weights, graph) - - @tvm.relax.expr_functor.visitor class BYOCChecker(PyExprVisitor): """Checker to check if any non-target ops exist""" diff --git a/python/tvm/contrib/msc/core/transform/pattern.py b/python/tvm/contrib/msc/core/transform/pattern.py index 135bac64ae80..1ef0076794ae 100644 --- a/python/tvm/contrib/msc/core/transform/pattern.py +++ b/python/tvm/contrib/msc/core/transform/pattern.py @@ -22,11 +22,9 @@ import tvm from tvm.relax.dpl import pattern as relax_pattern -from tvm.relay import dataflow_pattern as relay_pattern from tvm.relax.transform import PatternCheckContext from tvm.relax.backend.pattern_registry import register_patterns -from tvm.relay.op.contrib.register import register_pattern_table from tvm.contrib.msc.core.utils.namespace import MSCMap, MSCKey from tvm.contrib.msc.core import utils as msc_utils from tvm.contrib.msc.core import _ffi_api @@ -621,258 +619,3 @@ def _check_opt_relax_linear_bias(context: PatternCheckContext) -> bool: ), ] ) - - -# TODO(tong.meng): support patterns after optimize -@register_pattern_table("msc") -def pattern_table(): - """Returns list of triples describing the name, dataflow pattern and predicate for all - the MSC-supported operators.""" - - def make_relay_conv_bias_pattern( - op_name: str, optimized: bool = False - ) -> relay_pattern.DFPattern: - """A simple utility to create patterns for an operation fused with bias. - - Parameters - ---------- - op_name: str - The name of a Relay op, such as "relay.nn.conv2d" - optimized: bool - Whether the relay is optimized - - Returns - ------- - pattern: tvm.relay.dataflow_pattern.DFPattern - The resulting pattern describing a conv_bias operation - """ - - data = relay_pattern.wildcard() - weight = relay_pattern.is_constant() - bias = relay_pattern.is_constant() - conv = relay_pattern.is_op(op_name)(data, weight) - if optimized: - out = relay_pattern.is_op("add")(conv, bias) - else: - out = relay_pattern.is_op("nn.bias_add")(conv, bias) - return out - - def _check_relay_conv_bias(call: tvm.relay.Expr) -> bool: - """Check if conv_bias fuse pattern is correct. - - Returns - ------- - pass: bool - Whether the pattern is correct. - """ - - if call.op.name == "nn.bias_add": - bias = call.args[1] - return len(bias.checked_type.shape) == 1 - if call.op.name == "add": - return True - return False - - def make_relay_linear_pattern(optimized: bool = False) -> relay_pattern.DFPattern: - """A simple utility to create patterns for linear. 
- - Parameters - ---------- - optimized: bool - Whether the relay is optimized - - Returns - ------- - pattern: tvm.relay.dataflow_pattern.DFPattern - The resulting pattern describing a linear operation - """ - - if optimized: - data = relay_pattern.wildcard() - weight = relay_pattern.is_constant() - broadcast_data = relay_pattern.is_op("broadcast_to")(data) - reshape_data = relay_pattern.is_op("reshape")(broadcast_data) - batch_matmul = relay_pattern.is_op("nn.batch_matmul")(reshape_data, weight) - reshape_out = relay_pattern.is_op("reshape")(batch_matmul) - return relay_pattern.is_op("squeeze")(reshape_out) - data = relay_pattern.wildcard() - weight = relay_pattern.is_constant() - trans_weight = relay_pattern.is_op("transpose")(weight) - broadcast_data = relay_pattern.is_op("broadcast_to")(data) - broadcast_weight = relay_pattern.is_op("broadcast_to")(trans_weight) - reshape_data = relay_pattern.is_op("reshape")(broadcast_data) - reshape_weight = relay_pattern.is_op("reshape")(broadcast_weight) - batch_matmul = relay_pattern.is_op("nn.batch_matmul")(reshape_data, reshape_weight) - reshape_out = relay_pattern.is_op("reshape")(batch_matmul) - return relay_pattern.is_op("squeeze")(reshape_out) - - def _check_relay_linear(call: tvm.relay.Expr) -> bool: - """Check if linear pattern is correct. - - Returns - ------- - pass: bool - Whether the pattern is correct. - """ - - return True - - def make_relay_linear_bias_pattern(optimized: bool = False) -> relay_pattern.DFPattern: - """A simple utility to create patterns for linear_bias. - - Parameters - ---------- - optimized: bool - Whether the relay is optimized - - Returns - ------- - pattern: DFPattern - The resulting pattern describing a linear_bias operation - """ - - bias = relay_pattern.is_constant() - linear = make_relay_linear_pattern(optimized) - if optimized: - out = relay_pattern.is_op("add")(linear, bias) - else: - out = relay_pattern.is_op("nn.bias_add")(linear, bias) - return out - - def _check_relay_linear_bias(call: tvm.relay.Expr) -> bool: - """Check if linear_bias pattern is correct.""" - return True - - def make_relay_matmul_pattern(dim: int = 2, optimized: bool = False) -> relay_pattern.DFPattern: - """A simple utility to create patterns for matmul. - - Parameters - ---------- - optimized: bool - Whether the relay is optimized - - Returns - ------- - pattern: tvm.relay.dataflow_pattern.DFPattern - The resulting pattern describing a matmul operation - """ - - if dim == 2: - a = relay_pattern.wildcard() - b = relay_pattern.wildcard() - trans_b = relay_pattern.is_op("transpose")(b) - dense = relay_pattern.is_op("nn.dense")(a, trans_b) - return dense | relay_pattern.is_op("squeeze")(dense) - elif dim == 3: - a = relay_pattern.wildcard() - b = relay_pattern.wildcard() - broadcast_a = relay_pattern.is_op("broadcast_to")(a) - broadcast_b = relay_pattern.is_op("broadcast_to")(b) - reshape_a = relay_pattern.is_op("reshape")(broadcast_a) - reshape_b = relay_pattern.is_op("reshape")(broadcast_b) - batch_matmul = relay_pattern.is_op("nn.batch_matmul")(reshape_a, reshape_b) - reshape_out = relay_pattern.is_op("reshape")(batch_matmul) - return relay_pattern.is_op("squeeze")(reshape_out) - else: - raise Exception("matmul pattern only support dim 2 and 3") - - def _check_relay_matmul(call: tvm.relay.Expr) -> bool: - """Check if matmul pattern is correct. - - Returns - ------- - pass: bool - Whether the pattern is correct. 
- """ - - last_call = call.args[0] if call.op.name == "squeeze" else call - if last_call.op.name == "nn.dense": - trans_b = last_call.args[1] - b = trans_b.args[0] - if len(b.checked_type.shape) != 2: - return False - return trans_b.attrs["axes"] is None or list(trans_b.attrs["axes"]) == [1, 0] - return True - - def make_relay_embedding_pattern(optimized: bool = False) -> relay_pattern.DFPattern: - """A simple utility to create patterns for 1d embedding. - - Returns - ------- - pattern: tvm.relay.dataflow_pattern.DFPattern - The resulting pattern describing a embedding operation - """ - - weight = relay_pattern.is_constant() - data = relay_pattern.wildcard() - astype = relay_pattern.is_op("cast")(data) - return relay_pattern.is_op("take")(weight, astype) - - def _check_relay_embedding(call: tvm.relay.Expr) -> bool: - """Check if embedding pattern is correct. - - Returns - ------- - pass: bool - Whether the pattern is correct. - """ - - weight = call.args[0] - cast = call.args[1] - return ( - cast.attrs["dtype"] == "int32" - and len(weight.checked_type.shape) == 2 - and weight.checked_type.dtype == "float32" - ) - - def make_relay_gelu_pattern(optimized: bool = False) -> relay_pattern.DFPattern: - """A simple utility to create patterns for gelu. - - Returns - ------- - pattern: tvm.relay.dataflow_pattern.DFPattern - The resulting pattern describing a gelu operation. - """ - - data = relay_pattern.wildcard() - factor_1 = relay_pattern.is_constant() - mul_1 = relay_pattern.is_op("multiply")(data, factor_1) - erf = relay_pattern.is_op("erf")(mul_1) - factor_2 = relay_pattern.is_constant() - mul_2 = relay_pattern.is_op("multiply")(erf, factor_2) - factor_3 = relay_pattern.is_constant() - add = relay_pattern.is_op("add")(factor_3, mul_2) - return relay_pattern.is_op("multiply")(data, add) - - def _check_relay_gelu(call: tvm.relay.Expr) -> bool: - """Check if gelu pattern is correct. - - Returns - ------- - pass: bool - Whether the pattern is correct. 
- """ - - return True - - return [ - ("msc.conv1d_bias", make_relay_conv_bias_pattern("nn.conv1d"), _check_relay_conv_bias), - ( - "msc.conv1d_bias", - make_relay_conv_bias_pattern("nn.conv1d", True), - _check_relay_conv_bias, - ), - ("msc.conv2d_bias", make_relay_conv_bias_pattern("nn.conv2d"), _check_relay_conv_bias), - ( - "msc.conv2d_bias", - make_relay_conv_bias_pattern("nn.conv2d", True), - _check_relay_conv_bias, - ), - ("msc.linear_bias", make_relay_linear_bias_pattern(), _check_relay_linear_bias), - ("msc.linear", make_relay_linear_pattern(), _check_relay_linear), - ("msc.linear", make_relay_linear_pattern(True), _check_relay_linear), - ("msc.matmul", make_relay_matmul_pattern(dim=2), _check_relay_matmul), - ("msc.matmul", make_relay_matmul_pattern(dim=3), _check_relay_matmul), - ("msc.embedding", make_relay_embedding_pattern(), _check_relay_embedding), - ("msc.gelu", make_relay_gelu_pattern(), _check_relay_gelu), - ] diff --git a/python/tvm/contrib/msc/core/transform/transform.py b/python/tvm/contrib/msc/core/transform/transform.py index e78b5cb71450..47ea21266eb0 100644 --- a/python/tvm/contrib/msc/core/transform/transform.py +++ b/python/tvm/contrib/msc/core/transform/transform.py @@ -20,13 +20,11 @@ from typing import Dict import tvm -from tvm.relax.transform import _ffi_api as relax_api -from tvm.relay.transform import _ffi_api as relay_api +from tvm.relax.transform import _ffi_api from tvm.contrib.msc.core import utils as msc_utils def SetExprName( - as_relax: bool = True, entry_name: str = "main", target: str = "", var_names: Dict[str, str] = None, @@ -35,8 +33,6 @@ def SetExprName( Parameters ---------- - as_relax: bool - Whether set names for relax, otherwise for relay. entry_name: str The entry name target: str @@ -49,33 +45,9 @@ def SetExprName( ret: tvm.ir.transform.Pass """ - if as_relax: - var_names = var_names or {} - var_names = {k: msc_utils.legalize_expr_name(v) for k, v in var_names.items()} - return relax_api.SetRelaxExprName(entry_name, target, var_names) # type: ignore - return relay_api.SetRelaxExprName(entry_name) # type: ignore - - -def BindExprName( - name_key: str = "", seperator: str = ",", entry_name: str = "main" -) -> tvm.ir.transform.Pass: - """Bind name for the call and constant in IRModule. 
- - Parameters - ---------- - name_key: str - The key to find name - seperator: str - The seperator - entry_name: str - The entry name - - Returns - ------- - ret: tvm.ir.transform.Pass - """ - - return relay_api.BindRelaxExprName(name_key, seperator, entry_name) # type: ignore + var_names = var_names or {} + var_names = {k: msc_utils.legalize_expr_name(v) for k, v in var_names.items()} + return _ffi_api.SetRelaxExprName(entry_name, target, var_names) # type: ignore def SetExprLayout(allow_missing: bool = True, entry_name: str = "main") -> tvm.ir.transform.Pass: @@ -93,7 +65,7 @@ def SetExprLayout(allow_missing: bool = True, entry_name: str = "main") -> tvm.i ret: tvm.ir.transform.Pass """ - return relax_api.SetExprLayout(allow_missing, entry_name) # type: ignore + return _ffi_api.SetExprLayout(allow_missing, entry_name) # type: ignore def InlineParams(entry_name: str = "main") -> tvm.ir.transform.Pass: @@ -109,7 +81,7 @@ def InlineParams(entry_name: str = "main") -> tvm.ir.transform.Pass: ret: tvm.ir.transform.Pass """ - return relax_api.InlineParams(entry_name) # type: ignore + return _ffi_api.InlineParams(entry_name) # type: ignore def FuseTuple(target, entry_name: str = "main") -> tvm.ir.transform.Pass: @@ -127,7 +99,7 @@ def FuseTuple(target, entry_name: str = "main") -> tvm.ir.transform.Pass: ret: tvm.ir.transform.Pass """ - return relax_api.FuseTuple(target, entry_name) # type: ignore + return _ffi_api.FuseTuple(target, entry_name) # type: ignore def SetBYOCAttrs(target, entry_name: str = "main") -> tvm.ir.transform.Pass: @@ -145,7 +117,7 @@ def SetBYOCAttrs(target, entry_name: str = "main") -> tvm.ir.transform.Pass: ret: tvm.ir.transform.Pass """ - return relax_api.SetBYOCAttrs(target, entry_name) # type: ignore + return _ffi_api.SetBYOCAttrs(target, entry_name) # type: ignore def BindNamedParams( @@ -167,4 +139,4 @@ def BindNamedParams( ret: tvm.ir.transform.Pass """ - return relax_api.BindNamedParams(func_name, params) # type: ignore + return _ffi_api.BindNamedParams(func_name, params) # type: ignore diff --git a/python/tvm/contrib/msc/framework/tensorflow/frontend/translate.py b/python/tvm/contrib/msc/framework/tensorflow/frontend/translate.py index dab19ca81f83..6eb13b6ce076 100644 --- a/python/tvm/contrib/msc/framework/tensorflow/frontend/translate.py +++ b/python/tvm/contrib/msc/framework/tensorflow/frontend/translate.py @@ -14,6 +14,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. + +# pylint: disable=unused-argument """tvm.contrib.msc.framework.torch.frontend.translate""" from typing import Dict, Optional, Tuple, List, Union @@ -21,18 +23,13 @@ import tvm from tvm.contrib.msc.core.ir.graph import MSCGraph -from tvm.contrib.msc.core import transform as msc_transform -from tvm.contrib.msc.core.frontend import from_relax -from tvm.contrib.msc.core.codegen import relay_to_relax from tvm.contrib.msc.framework.tensorflow import tf_v1 -from tvm.contrib.msc.core import utils as msc_utils def from_tensorflow( graph_def: tf_v1.GraphDef, shape_dict: Dict[str, List[int]], outputs: List[str], - via_relax: bool = False, trans_config: Optional[Dict[str, str]] = None, build_config: Optional[Dict[str, str]] = None, opt_config: Optional[Dict[str, str]] = None, @@ -48,8 +45,6 @@ def from_tensorflow( The shape dict of inputs. outputs: list The output names. - via_relax: bool - Whether translate torch to relax. trans_config: dict The config for transform IRModule. 
build_config: dict @@ -67,16 +62,4 @@ def from_tensorflow( The weights from the IRModule. """ - assert not via_relax, "Relax frontend for tensorflow is not supported" - relay_mod, params = tvm.relay.frontend.from_tensorflow( - graph_def, shape=shape_dict, outputs=outputs - ) - passes = [msc_transform.BindExprName()] - relay_mod = tvm.transform.Sequential(passes)(relay_mod) - relax_mod = relay_to_relax(relay_mod, params, trans_config, build_config, opt_config) - if not as_msc: - return relax_mod, params - build_config = msc_utils.copy_dict(build_config) - build_config["use_var_name"] = True - graph, weights = from_relax(relax_mod, trans_config=trans_config, build_config=build_config) - return graph, weights + raise NotImplementedError("translate relax module from tensorflow is not implemented") diff --git a/python/tvm/contrib/msc/framework/torch/frontend/translate.py b/python/tvm/contrib/msc/framework/torch/frontend/translate.py index 04597bd3419b..32626b904c30 100644 --- a/python/tvm/contrib/msc/framework/torch/frontend/translate.py +++ b/python/tvm/contrib/msc/framework/torch/frontend/translate.py @@ -17,15 +17,12 @@ """tvm.contrib.msc.framework.torch.frontend.translate""" from typing import Dict, Optional, Tuple, List, Union -import numpy as np import torch import tvm from tvm.relax.frontend.torch import from_fx from tvm.contrib.msc.core.ir.graph import MSCGraph from tvm.contrib.msc.core.frontend import from_relax, normalize_inputs -from tvm.contrib.msc.core.codegen import relay_to_relax -from tvm.contrib.msc.core import utils as msc_utils def set_weight_alias(graph: MSCGraph) -> MSCGraph: @@ -36,6 +33,7 @@ def set_weight_alias(graph: MSCGraph) -> MSCGraph: graph: MSCGraph The graph. + Returns ------- graph: MSCGraph @@ -64,14 +62,10 @@ def set_weight_alias(graph: MSCGraph) -> MSCGraph: def from_torch( model: torch.nn.Module, input_info: List[Tuple[Tuple[int], str]], - input_names: List[str] = None, - via_relax: bool = True, trans_config: Optional[Dict[str, str]] = None, build_config: Optional[Dict[str, str]] = None, - opt_config: Optional[Dict[str, str]] = None, as_msc: bool = True, custom_convert_map: dict = None, - build_folder: msc_utils.MSCDirectory = None, ) -> Tuple[Union[MSCGraph, tvm.IRModule], Dict[str, tvm.nd.array]]: """Change torch nn.Module to MSCGraph. @@ -83,8 +77,6 @@ def from_torch( The input info in format [(shape, dtype)]. input_names: list The input names. - via_relax: bool - Whether translate torch to relax. trans_config: dict The config for transform IRModule. build_config: dict @@ -106,35 +98,10 @@ def from_torch( The weights from the IRModule. 
""" - # try to symbolic_trace - if via_relax: - try: - graph_model = torch.fx.symbolic_trace(model) - except: # pylint: disable=bare-except - via_relax = False - - if via_relax: - input_info, params = normalize_inputs(input_info), None - with torch.no_grad(): - relax_mod = from_fx(graph_model, input_info, custom_convert_map=custom_convert_map) - else: - datas = [np.random.rand(*i[0]).astype(i[1]) for i in input_info] - torch_datas = [torch.from_numpy(i) for i in datas] - with torch.no_grad(): - scripted_model = torch.jit.trace(model, tuple(torch_datas)).eval() - if input_names: - assert len(input_names) == len( - input_info - ), "input_names {} length mismatch with input_info {}".format(input_names, input_info) - shape_list = list(zip(input_names, input_info)) - else: - shape_list = [("input" + str(idx), i_info) for idx, i_info in enumerate(input_info)] - relay_mod, params = tvm.relay.frontend.from_pytorch( - scripted_model, shape_list, custom_convert_map=custom_convert_map - ) - relax_mod = relay_to_relax( - relay_mod, params, trans_config, build_config, opt_config, build_folder=build_folder - ) + graph_model = torch.fx.symbolic_trace(model) + input_info, params = normalize_inputs(input_info), None + with torch.no_grad(): + relax_mod = from_fx(graph_model, input_info, custom_convert_map=custom_convert_map) if not as_msc: return relax_mod, params graph, weights = from_relax(relax_mod, trans_config=trans_config, build_config=build_config) diff --git a/src/contrib/msc/core/ir/graph_builder.cc b/src/contrib/msc/core/ir/graph_builder.cc index 1cc0c4af6a3b..2244562a2f46 100644 --- a/src/contrib/msc/core/ir/graph_builder.cc +++ b/src/contrib/msc/core/ir/graph_builder.cc @@ -30,6 +30,8 @@ namespace tvm { namespace contrib { namespace msc { +using namespace tvm::relax; + const std::string GetScalarStr(const runtime::NDArray& data, int float_precision) { std::string scalar_str; if (data->dtype.code == kDLFloat) { @@ -44,7 +46,7 @@ const std::string GetScalarStr(const runtime::NDArray& data, int float_precision return scalar_str; } -void RelaxFuncAttrGetter::VisitExpr_(const relax::CallNode* op) { +void FuncAttrGetter::VisitExpr_(const CallNode* op) { if (op->attrs.defined()) { Map attrs; AttrGetter getter(&attrs); @@ -64,18 +66,17 @@ void RelaxFuncAttrGetter::VisitExpr_(const relax::CallNode* op) { } } -void RelaxFuncAttrGetter::VisitExpr_(const relax::TupleGetItemNode* op) { +void FuncAttrGetter::VisitExpr_(const TupleGetItemNode* op) { attrs_.Set("index", std::to_string(op->index)); } -void RelaxFuncValueGetter::VisitExpr_(const relax::CallNode* op) { +void FuncValueGetter::VisitExpr_(const CallNode* op) { for (const auto& arg : op->args) { - if (const auto* s_node = arg.as()) { + if (const auto* s_node = arg.as()) { values_.push_back(StringUtils::ToString(s_node->value)); - } else if (const auto* s_node = arg.as()) { - bool all_values = - std::all_of(s_node->fields.begin(), s_node->fields.end(), - [](const relax::Expr& e) { return e->IsInstance(); }); + } else if (const auto* s_node = arg.as()) { + bool all_values = std::all_of(s_node->fields.begin(), s_node->fields.end(), + [](const Expr& e) { return e->IsInstance(); }); if (all_values) { values_.push_back(StringUtils::ToString(s_node->fields)); } @@ -83,25 +84,24 @@ void RelaxFuncValueGetter::VisitExpr_(const relax::CallNode* op) { } } -void RelaxFuncParamsFinder::VisitBinding_(const relax::VarBindingNode* binding, - const relax::FunctionNode* val) { - local_funcs_.Set(binding->var, GetRef(val)); +void FuncParamsFinder::VisitBinding_(const 
VarBindingNode* binding, const FunctionNode* val) { + local_funcs_.Set(binding->var, GetRef(val)); } -void RelaxFuncParamsFinder::VisitExpr_(const relax::CallNode* call_node) { - RelaxExprVisitor::VisitExpr_(call_node); - relax::Function func; +void FuncParamsFinder::VisitExpr_(const CallNode* call_node) { + ExprVisitor::VisitExpr_(call_node); + Function func; if (const auto* v_node = call_node->op.as()) { - func = Downcast(ref_module_->Lookup(v_node->name_hint)); - } else if (call_node->op->IsInstance()) { + func = Downcast(ref_module_->Lookup(v_node->name_hint)); + } else if (call_node->op->IsInstance()) { ICHECK(local_funcs_.count(call_node->op)) << "Can not find local func " << call_node->op; func = local_funcs_[call_node->op]; } if (func.defined()) { for (size_t i = 0; i < call_node->args.size(); i++) { const auto& arg = call_node->args[i]; - if (arg->IsInstance() && params_.count(Downcast(arg))) { - params_.Set(func->params[i], params_[Downcast(arg)]); + if (arg->IsInstance() && params_.count(Downcast(arg))) { + params_.Set(func->params[i], params_[Downcast(arg)]); } else { params_.Set(func->params[i], arg); } @@ -109,18 +109,17 @@ void RelaxFuncParamsFinder::VisitExpr_(const relax::CallNode* call_node) { } } -void RelaxLayoutsFinder::VisitBinding_(const relax::VarBindingNode* binding, - const relax::FunctionNode* val) { - local_funcs_.Set(binding->var, GetRef(val)); +void LayoutsFinder::VisitBinding_(const VarBindingNode* binding, const FunctionNode* val) { + local_funcs_.Set(binding->var, GetRef(val)); } -void RelaxLayoutsFinder::VisitExpr_(const relax::CallNode* call_node) { - RelaxExprVisitor::VisitExpr_(call_node); - relax::Function func; +void LayoutsFinder::VisitExpr_(const CallNode* call_node) { + ExprVisitor::VisitExpr_(call_node); + Function func; if (const auto* v_node = call_node->op.as()) { - func = Downcast(ref_module_->Lookup(v_node->name_hint)); + func = Downcast(ref_module_->Lookup(v_node->name_hint)); VisitExpr(func); - } else if (call_node->op->IsInstance()) { + } else if (call_node->op->IsInstance()) { ICHECK(local_funcs_.count(call_node->op)) << "Can not find local func " << call_node->op; func = local_funcs_[call_node->op]; } @@ -134,7 +133,7 @@ void RelaxLayoutsFinder::VisitExpr_(const relax::CallNode* call_node) { } } -const MSCGraph RelaxGraphBuilder::Build(const relax::Function& func) { +const MSCGraph GraphBuilder::Build(const Function& func) { // Add input nodes and record inputs; Array input_names, output_names; std::set added_inputs; @@ -143,7 +142,7 @@ const MSCGraph RelaxGraphBuilder::Build(const relax::Function& func) { if (!p->struct_info_.defined()) { continue; } - if (p->struct_info_.value()->IsInstance()) { + if (p->struct_info_.value()->IsInstance()) { const auto& shape = ExprUtils::GetShape(p, false); for (size_t i = 0; i < shape.size(); i++) { if (shape[i]->IsInstance()) { @@ -163,16 +162,16 @@ const MSCGraph RelaxGraphBuilder::Build(const relax::Function& func) { if (expr_tensor_map_.count(p)) { continue; } - if (func_params_.count(p) && func_params_[p]->IsInstance()) { + if (func_params_.count(p) && func_params_[p]->IsInstance()) { continue; } - if (func_params_.count(p) && func_params_[p]->IsInstance()) { - const auto& tuple = Downcast(func_params_[p]); + if (func_params_.count(p) && func_params_[p]->IsInstance()) { + const auto& tuple = Downcast(func_params_[p]); Array tuple_names; for (const auto& f : tuple->fields) { if (expr_tensor_map_.count(f)) { LOG_INFO << "Replica tuple input " << f; - } else if (const auto* f_node = f.as()) { + } 
else if (const auto* f_node = f.as()) { AddNode(f, NullOpt, f_node->name_hint()); } else { LOG_FATAL << "Unexpected tuple input " << f << "(" << f->GetTypeKey() << ")"; @@ -254,40 +253,40 @@ const MSCGraph RelaxGraphBuilder::Build(const relax::Function& func) { return graph; } -const MSCJoint RelaxGraphBuilder::AddNode(const Expr& expr, const Optional& binding_var, - const String& name) { +const MSCJoint GraphBuilder::AddNode(const Expr& expr, const Optional& binding_var, + const String& name) { // Get optype, node_name and layout String node_name = name.size() > 0 ? name : SpanUtils::GetAttr(expr->span, msc_attr::kName); String optype = "unknown"; String layout = SpanUtils::GetAttr(expr->span, msc_attr::kLayout); - if (func_params_.count(expr) && func_params_[expr]->IsInstance()) { + if (func_params_.count(expr) && func_params_[expr]->IsInstance()) { node_name = SpanUtils::GetAttr(func_params_[expr]->span, msc_attr::kName); optype = "constant"; - } else if (expr->IsInstance()) { + } else if (expr->IsInstance()) { optype = "input"; - } else if (expr->IsInstance()) { + } else if (expr->IsInstance()) { optype = "constant"; - } else if (expr->IsInstance()) { + } else if (expr->IsInstance()) { optype = "shape"; - } else if (expr->IsInstance()) { + } else if (expr->IsInstance()) { optype = "get_item"; - } else if (expr->IsInstance()) { + } else if (expr->IsInstance()) { optype = "tuple"; - } else if (const auto* call_node = expr.as()) { + } else if (const auto* call_node = expr.as()) { if (const auto* op_node = call_node->op.as()) { if (op_node->name == "relax.call_dps_packed") { - optype = Downcast(call_node->args[0])->global_symbol; + optype = Downcast(call_node->args[0])->global_symbol; } else { optype = StringUtils::Replace(op_node->name, "relax.", ""); } } else if (const auto* v_node = call_node->op.as()) { - const auto& func = Downcast(ref_module_->Lookup(v_node->name_hint)); + const auto& func = Downcast(ref_module_->Lookup(v_node->name_hint)); std::tie(node_name, optype, layout) = ParseFunc(func); - } else if (call_node->op->IsInstance()) { + } else if (call_node->op->IsInstance()) { ICHECK(target_funcs_.count(call_node->op)) << "Can not find target func: " << call_node->op; std::tie(node_name, optype, layout) = ParseFunc(target_funcs_[call_node->op]); - } else if (call_node->op->IsInstance()) { - std::tie(node_name, optype, layout) = ParseFunc(Downcast(call_node->op)); + } else if (call_node->op->IsInstance()) { + std::tie(node_name, optype, layout) = ParseFunc(Downcast(call_node->op)); } } if (layouts_.count(node_name)) { @@ -295,9 +294,9 @@ const MSCJoint RelaxGraphBuilder::AddNode(const Expr& expr, const Optional } // specail case for tuple - if (optype == "tuple" && expr->IsInstance() && - Downcast(expr)->op->IsInstance()) { - const auto& call_node = Downcast(expr); + if (optype == "tuple" && expr->IsInstance() && + Downcast(expr)->op->IsInstance()) { + const auto& call_node = Downcast(expr); ICHECK(target_funcs_.count(call_node->op)) << "Can not find target func: " << call_node->op; const auto& tuple_func = target_funcs_[call_node->op]; for (size_t i = 0; i < call_node->args.size(); i++) { @@ -319,7 +318,7 @@ const MSCJoint RelaxGraphBuilder::AddNode(const Expr& expr, const Optional // Extract normal attributes Map attrs; if (plugin.defined()) { - const auto& op = Downcast(expr)->op; + const auto& op = Downcast(expr)->op; if (target_funcs_.count(op)) { const auto& opattrs_opt = target_funcs_[op]->GetAttr>(msc_attr::kOpattrs); if (opattrs_opt.defined()) { @@ -337,55 +336,55 @@ 
const MSCJoint RelaxGraphBuilder::AddNode(const Expr& expr, const Optional attrs.Set(plugin->attrs[i]->name, StringUtils::ToString(val)); } } - } else if (const auto* call_node = expr.as()) { + } else if (const auto* call_node = expr.as()) { if (const auto* v_node = call_node->op.as()) { - const auto& func = Downcast(ref_module_->Lookup(v_node->name_hint)); + const auto& func = Downcast(ref_module_->Lookup(v_node->name_hint)); const auto& name_opt = func->GetAttr(relax::attr::kComposite); if (name_opt.defined()) { - attrs = RelaxFuncAttrGetter().GetAttrs(func); + attrs = FuncAttrGetter().GetAttrs(func); } - } else if (call_node->op->IsInstance()) { + } else if (call_node->op->IsInstance()) { ICHECK(target_funcs_.count(call_node->op)) << "Can not find target func: " << call_node->op; - attrs = RelaxFuncAttrGetter().GetAttrs(target_funcs_[call_node->op]); - } else if (call_node->op->IsInstance()) { - attrs = RelaxFuncAttrGetter().GetAttrs(call_node->op); + attrs = FuncAttrGetter().GetAttrs(target_funcs_[call_node->op]); + } else if (call_node->op->IsInstance()) { + attrs = FuncAttrGetter().GetAttrs(call_node->op); } else if (call_node->attrs.defined()) { AttrGetter getter(&attrs); const_cast(call_node->attrs.get())->VisitAttrs(&getter); } - } else if (const auto* const_node = expr.as()) { + } else if (const auto* const_node = expr.as()) { if (const_node->is_scalar()) { attrs.Set("scalar", GetScalarStr(const_node->data, config_.float_precision)); } - } else if (const auto* shape_node = expr.as()) { + } else if (const auto* shape_node = expr.as()) { attrs.Set("shape", StringUtils::ToString(shape_node->values)); - } else if (const auto* get_node = expr.as()) { + } else if (const auto* get_node = expr.as()) { attrs.Set("index", std::to_string(get_node->index)); } // Extract attributes from arguments Array input_types; - if (!plugin.defined() && expr->IsInstance()) { - const auto& call = Downcast(expr); + if (!plugin.defined() && expr->IsInstance()) { + const auto& call = Downcast(expr); Array values; - if (call->op->IsInstance()) { + if (call->op->IsInstance()) { ICHECK(target_funcs_.count(call->op)) << "Can not find target func: " << call->op; - values = RelaxFuncValueGetter().GetValues(target_funcs_[call->op]); + values = FuncValueGetter().GetValues(target_funcs_[call->op]); } input_types = ExprUtils::GetInputTypes(optype, call->args.size() + values.size(), true); for (size_t i = 0; i < call->args.size(); i++) { const auto& arg = call->args[i]; - if (const auto* s_node = arg.as()) { + if (const auto* s_node = arg.as()) { attrs.Set(input_types[i], StringUtils::ToString(s_node->values)); - } else if (func_params_.count(arg) && func_params_[arg]->IsInstance()) { - const auto* s_node = func_params_[arg].as(); + } else if (func_params_.count(arg) && func_params_[arg]->IsInstance()) { + const auto* s_node = func_params_[arg].as(); attrs.Set(input_types[i], StringUtils::ToString(s_node->values)); - ignore_nodes_.insert(Downcast(arg)->name_hint()); - } else if (const auto* s_node = arg.as()) { + ignore_nodes_.insert(Downcast(arg)->name_hint()); + } else if (const auto* s_node = arg.as()) { ICHECK(input_types[i] != "input") << i << " th PrimValue of " << optype << " should has special type, get " << input_types; attrs.Set(input_types[i], StringUtils::ToString(s_node->value)); - } else if (input_types[i] != "input" && arg->IsInstance()) { + } else if (input_types[i] != "input" && arg->IsInstance()) { attrs.Set(input_types[i], StringUtils::ToString(arg)); } } @@ -398,7 +397,7 @@ const MSCJoint 
RelaxGraphBuilder::AddNode(const Expr& expr, const Optional Array input_names; Map node_weights; if (plugin.defined()) { - const auto& call = Downcast(expr); + const auto& call = Downcast(expr); if (call->args.size() == 1) { ICHECK(expr_tensor_map_.count(call->args[0])) << "Can not find tuple plugin input " << call->args[0]; @@ -412,7 +411,7 @@ const MSCJoint RelaxGraphBuilder::AddNode(const Expr& expr, const Optional } } } - } else if (const auto* call_node = expr.as()) { + } else if (const auto* call_node = expr.as()) { for (size_t i = 0; i < call_node->args.size(); i++) { if (attrs.count(input_types[i])) { continue; @@ -421,8 +420,8 @@ const MSCJoint RelaxGraphBuilder::AddNode(const Expr& expr, const Optional Array arg_names; if (expr_tensor_map_.count(arg)) { arg_names = expr_tensor_map_[arg]; - } else if (input_types[i] == "input" && arg->IsInstance()) { - const auto* tuple_node = arg.as(); + } else if (input_types[i] == "input" && arg->IsInstance()) { + const auto* tuple_node = arg.as(); for (const auto& f : tuple_node->fields) { ICHECK(expr_tensor_map_.count(f)) << "Can not find tuple field " << f; for (const auto& in_name : expr_tensor_map_[f]) { @@ -431,12 +430,12 @@ const MSCJoint RelaxGraphBuilder::AddNode(const Expr& expr, const Optional } } String weight_name; - if (input_types[i] != "input" && arg->IsInstance()) { + if (input_types[i] != "input" && arg->IsInstance()) { weight_name = SpanUtils::GetAttr(arg->span, msc_attr::kName); } else if (input_types[i] != "input" && func_params_.count(arg) && - func_params_[arg]->IsInstance()) { + func_params_[arg]->IsInstance()) { weight_name = SpanUtils::GetAttr(func_params_[arg]->span, msc_attr::kName); - ignore_nodes_.insert(Downcast(arg)->name_hint()); + ignore_nodes_.insert(Downcast(arg)->name_hint()); } // set weights or inputs if (weight_name.size() > 0) { @@ -472,19 +471,19 @@ const MSCJoint RelaxGraphBuilder::AddNode(const Expr& expr, const Optional } } } - } else if (const auto* tuple_node = expr.as()) { + } else if (const auto* tuple_node = expr.as()) { for (const auto& f : tuple_node->fields) { ICHECK(expr_tensor_map_.count(f)) << "Can not find tuple field " << f; for (const auto& in_name : expr_tensor_map_[f]) { input_names.push_back(in_name); } } - } else if (const auto* getitem_node = expr.as()) { + } else if (const auto* getitem_node = expr.as()) { ICHECK(expr_tensor_map_.count(getitem_node->tuple)) << "Can not find tuple " << getitem_node->tuple; input_names = expr_tensor_map_[getitem_node->tuple]; } else if (optype == "constant") { - const auto& t_info = Downcast(relax::GetStructInfo(expr)); + const auto& t_info = Downcast(GetStructInfo(expr)); const auto& shape_opt = t_info->GetShape(); ICHECK(shape_opt.defined()) << "Constant shape is not defined"; const auto& weight = @@ -511,11 +510,11 @@ const MSCJoint RelaxGraphBuilder::AddNode(const Expr& expr, const Optional } // Build output tensor - auto build_output = [this](const relax::StructInfo& sinfo, const String& node_name, + auto build_output = [this](const StructInfo& sinfo, const String& node_name, const String& layout) { - ICHECK(sinfo->IsInstance()) + ICHECK(sinfo->IsInstance()) << "sinfo should be TensorStructInfo, get " << sinfo->GetTypeKey(); - const auto& t_info = Downcast(sinfo); + const auto& t_info = Downcast(sinfo); const auto& shape = ArrayUtils::Cast(ExprUtils::GetShape(t_info)); Array prims; bool has_prims = false; @@ -537,10 +536,10 @@ const MSCJoint RelaxGraphBuilder::AddNode(const Expr& expr, const Optional // Gather outputs Array outputs; - const 
auto& sinfo = relax::GetStructInfo(expr); + const auto& sinfo = GetStructInfo(expr); Array layouts = StringUtils::Split(layout, ","); size_t num_output = 1; - if (const auto* tuple_sinfo = sinfo.as()) { + if (const auto* tuple_sinfo = sinfo.as()) { num_output = tuple_sinfo->fields.size(); } if (layouts.size() == 0) { @@ -548,15 +547,15 @@ const MSCJoint RelaxGraphBuilder::AddNode(const Expr& expr, const Optional } ICHECK_EQ(layouts.size(), num_output) << "Layouts " << layouts << " msimatch with output size " << num_output; - if (sinfo->IsInstance()) { + if (sinfo->IsInstance()) { const auto& t_name = node_name + ":" + std::to_string(0); outputs.push_back(build_output(sinfo, t_name, layouts[0])); - } else if (const auto* s_sinfo = sinfo.as()) { + } else if (const auto* s_sinfo = sinfo.as()) { Array shape{s_sinfo->ndim}; const auto& t_name = node_name + ":" + std::to_string(0); const auto& dtype = DataType(runtime::String2DLDataType("int32")); outputs.push_back(MSCTensor(t_name, dtype, layouts[0], shape)); - } else if (const auto* tuple_sinfo = sinfo.as()) { + } else if (const auto* tuple_sinfo = sinfo.as()) { size_t field_size = optype == "nn.batch_norm" ? 1 : num_output; for (size_t i = 0; i < field_size; i++) { const auto& t_name = node_name + ":" + std::to_string(i); @@ -585,7 +584,7 @@ const MSCJoint RelaxGraphBuilder::AddNode(const Expr& expr, const Optional return node; } -void RelaxGraphBuilder::VisitBindingBlock(const relax::BindingBlock& block) { +void GraphBuilder::VisitBindingBlock(const BindingBlock& block) { String block_name = SpanUtils::GetAttr(block->span, msc_attr::kName); if (block_name.size() == 0) { block_name = "block"; @@ -601,7 +600,7 @@ void RelaxGraphBuilder::VisitBindingBlock(const relax::BindingBlock& block) { scope_name_ = prefix + "." + block_name; setted_blocks_.insert(scope_name_); block_stack_.push_back(block_name); - RelaxExprVisitor::VisitBindingBlock(block); + ExprVisitor::VisitBindingBlock(block); block_stack_.pop_back(); } @@ -611,7 +610,7 @@ void RelaxGraphBuilder::VisitBindingBlock(const relax::BindingBlock& block) { return MatchOrCreatePrim(prim, "", {AddPrim(binary->a), AddPrim(binary->b)}); \ } -const MSCPrim RelaxGraphBuilder::AddPrim(const PrimExpr& prim) { +const MSCPrim GraphBuilder::AddPrim(const PrimExpr& prim) { if (prim_map_.count(prim)) { return prim_map_[prim]; } @@ -659,9 +658,9 @@ const MSCPrim RelaxGraphBuilder::AddPrim(const PrimExpr& prim) { return MatchOrCreatePrim(prim); } -const MSCPrim RelaxGraphBuilder::MatchOrCreatePrim(const PrimExpr& prim, const String& optype, - const Array& parents, - const Map& attrs) { +const MSCPrim GraphBuilder::MatchOrCreatePrim(const PrimExpr& prim, const String& optype, + const Array& parents, + const Map& attrs) { if (prim_map_.count(prim)) { return prim_map_[prim]; } @@ -703,30 +702,27 @@ const MSCPrim RelaxGraphBuilder::MatchOrCreatePrim(const PrimExpr& prim, const S return node; } -void RelaxGraphBuilder::VisitExpr_(const relax::ConstantNode* op) { - if (!expr_tensor_map_.count(GetRef(op))) { - AddNode(GetRef(op)); +void GraphBuilder::VisitExpr_(const ConstantNode* op) { + if (!expr_tensor_map_.count(GetRef(op))) { + AddNode(GetRef(op)); } } -void RelaxGraphBuilder::VisitBinding_(const relax::VarBindingNode* binding, - const relax::ConstantNode* val) { +void GraphBuilder::VisitBinding_(const VarBindingNode* binding, const ConstantNode* val) { const String& name = config_.use_var_name ? 
binding->var->name_hint() : ""; - AddNode(GetRef(val), binding->var, name); + AddNode(GetRef(val), binding->var, name); } -void RelaxGraphBuilder::VisitBinding_(const relax::VarBindingNode* binding, - const relax::ShapeExprNode* val) { +void GraphBuilder::VisitBinding_(const VarBindingNode* binding, const ShapeExprNode* val) { const String& name = config_.use_var_name ? binding->var->name_hint() : ""; - AddNode(GetRef(val), binding->var, name); + AddNode(GetRef(val), binding->var, name); } -void RelaxGraphBuilder::VisitBinding_(const relax::VarBindingNode* binding, - const relax::CallNode* call_node) { - RelaxExprVisitor::VisitBinding_(binding, call_node); +void GraphBuilder::VisitBinding_(const VarBindingNode* binding, const CallNode* call_node) { + ExprVisitor::VisitBinding_(binding, call_node); const String& name = config_.use_var_name ? binding->var->name_hint() : ""; try { - AddNode(GetRef(call_node), binding->var, name); + AddNode(GetRef(call_node), binding->var, name); } catch (runtime::InternalError& err) { LOG(WARNING) << "Failed to add node from " << binding->var << " : " << binding->value << ", reason: " << err.message(); @@ -734,46 +730,41 @@ void RelaxGraphBuilder::VisitBinding_(const relax::VarBindingNode* binding, } } -void RelaxGraphBuilder::VisitBinding_(const relax::VarBindingNode* binding, - const relax::TupleNode* val) { - RelaxExprVisitor::VisitBinding_(binding, val); +void GraphBuilder::VisitBinding_(const VarBindingNode* binding, const TupleNode* val) { + ExprVisitor::VisitBinding_(binding, val); const String& name = config_.use_var_name ? binding->var->name_hint() : ""; - AddNode(GetRef(val), binding->var, name); + AddNode(GetRef(val), binding->var, name); } -void RelaxGraphBuilder::VisitBinding_(const relax::VarBindingNode* binding, - const relax::TupleGetItemNode* val) { - RelaxExprVisitor::VisitBinding_(binding, val); +void GraphBuilder::VisitBinding_(const VarBindingNode* binding, const TupleGetItemNode* val) { + ExprVisitor::VisitBinding_(binding, val); const String& name = config_.use_var_name ? 
binding->var->name_hint() : ""; - AddNode(GetRef(val), binding->var, name); + AddNode(GetRef(val), binding->var, name); } -void RelaxGraphBuilder::VisitBinding_(const relax::VarBindingNode* binding, - const relax::VarNode* val) { - RelaxExprVisitor::VisitBinding_(binding, val); - const auto& output = GetRef(val); +void GraphBuilder::VisitBinding_(const VarBindingNode* binding, const VarNode* val) { + ExprVisitor::VisitBinding_(binding, val); + const auto& output = GetRef(val); ICHECK(expr_tensor_map_.count(output)) << "Can not find var " << output; expr_tensor_map_.Set(binding->var, expr_tensor_map_[output]); } -void RelaxGraphBuilder::VisitBinding_(const relax::VarBindingNode* binding, - const relax::DataflowVarNode* val) { - RelaxExprVisitor::VisitBinding_(binding, val); - const auto& output = GetRef(val); +void GraphBuilder::VisitBinding_(const VarBindingNode* binding, const DataflowVarNode* val) { + ExprVisitor::VisitBinding_(binding, val); + const auto& output = GetRef(val); ICHECK(expr_tensor_map_.count(output)) << "Can not find dataflow var " << output; expr_tensor_map_.Set(binding->var, expr_tensor_map_[output]); } -void RelaxGraphBuilder::VisitBinding_(const relax::VarBindingNode* binding, - const relax::FunctionNode* val) { - const auto& name_opt = val->GetAttr(relay::attr::kComposite); +void GraphBuilder::VisitBinding_(const VarBindingNode* binding, const FunctionNode* val) { + const auto& name_opt = val->GetAttr(relax::attr::kComposite); ICHECK(name_opt.defined()) << "Unexpected target func without composite"; ICHECK(config_.target.size() > 0 && StringUtils::StartsWith(name_opt.value(), config_.target)) << "Target should be given for target function"; - target_funcs_.Set(binding->var, GetRef(val)); + target_funcs_.Set(binding->var, GetRef(val)); } -const std::tuple RelaxGraphBuilder::ParseFunc(const relax::Function& func) { +const std::tuple GraphBuilder::ParseFunc(const Function& func) { String node_name, optype, layout; const auto& name_opt = func->GetAttr(msc_attr::kUnique); // get node_name @@ -802,32 +793,32 @@ const std::tuple RelaxGraphBuilder::ParseFunc(const rela return std::make_tuple(node_name, optype, layout); } -void RelaxGraphBuilder::VisitPrimExpr(const PrimExpr& prim) { - RelaxExprVisitor::VisitPrimExpr(prim); +void GraphBuilder::VisitPrimExpr(const PrimExpr& prim) { + ExprVisitor::VisitPrimExpr(prim); if (!prim->IsInstance() && !prim->IsInstance()) { AddPrim(prim); } } -Array RelaxGraphBuilder::GetPluginInputs(const relax::Expr& expr) { - ICHECK(expr->IsInstance()) << "plugin expr should be call"; - const auto& call = Downcast(expr); - ICHECK(call->args[1]->IsInstance()) << "plugin argument 1 should be call"; - return Downcast(call->args[1])->fields; +Array GraphBuilder::GetPluginInputs(const Expr& expr) { + ICHECK(expr->IsInstance()) << "plugin expr should be call"; + const auto& call = Downcast(expr); + ICHECK(call->args[1]->IsInstance()) << "plugin argument 1 should be call"; + return Downcast(call->args[1])->fields; } -Map RelaxWeightsExtractor::GetWeights(const relax::Function& func) { +Map WeightsExtractor::GetWeights(const Function& func) { VisitExpr(func); return weights_; } -void RelaxWeightsExtractor::VisitExpr_(const relax::ConstantNode* op) { +void WeightsExtractor::VisitExpr_(const ConstantNode* op) { const auto& name = SpanUtils::GetAttr(op->span, msc_attr::kName); const auto& layout = SpanUtils::GetAttr(op->span, msc_attr::kLayout); - const auto& sinfo = relax::GetStructInfo(GetRef(op)); - ICHECK(sinfo->IsInstance()) + const auto& sinfo = 
GetStructInfo(GetRef(op)); + ICHECK(sinfo->IsInstance()) << "Constant StrcutInfo should be TensorStructInfo"; - const auto& t_info = Downcast(sinfo); + const auto& t_info = Downcast(sinfo); const auto& opt_shape = t_info->GetShape(); const auto& shape = opt_shape.defined() ? ArrayUtils::Cast(opt_shape.value()) : Array(); @@ -835,378 +826,29 @@ void RelaxWeightsExtractor::VisitExpr_(const relax::ConstantNode* op) { weights_.Set(weight, op->data); } -void RelaxWeightsExtractor::VisitExpr_(const relax::CallNode* op) { - RelaxExprVisitor::VisitExpr_(op); +void WeightsExtractor::VisitExpr_(const CallNode* op) { + ExprVisitor::VisitExpr_(op); if (const auto* v_node = op->op.as()) { - const auto& func = Downcast(ref_module_->Lookup(v_node->name_hint)); + const auto& func = Downcast(ref_module_->Lookup(v_node->name_hint)); VisitExpr(func); } } -void RelayFuncAttrGetter::VisitExpr_(const relay::CallNode* op) { - RelaxExprVisitor::VisitExpr_(op); - if (op->attrs.defined()) { - Map attrs; - AttrGetter getter(&attrs); - const_cast(op->attrs.get())->VisitAttrs(&getter); - for (const auto& pair : attrs) { - if (attrs_.count(pair.first)) { - int cnt = 1; - String rep_key = pair.first; - while (attrs_.count(rep_key + "_" + std::to_string(cnt))) { - cnt++; - } - attrs_.Set(pair.first + "_" + std::to_string(cnt), pair.second); - } else { - attrs_.Set(pair.first, pair.second); - } - } - } -} - -MSCGraph RelayGraphBuilder::Build(const relay::Function& func) { - // Add input nodes and record inputs; - Array input_names, output_names; - for (const auto& p : func->params) { - AddNode(p, p->name_hint()); - ICHECK(expr_tensor_map_.count(p)) << "Can not find func param " << p; - input_names.push_back(expr_tensor_map_[p][0]); - } - VisitExpr(func); - ICHECK(expr_tensor_map_.count(func->body)) << "Can not find func body " << func->body; - output_names = expr_tensor_map_[func->body]; - // remove const nodes as weights - Array valid_nodes; - for (const auto& n : nodes_) { - if (!weights_.count(n->name)) { - n->index = valid_nodes.size(); - valid_nodes.push_back(n); - } - } - const auto& graph = MSCGraph(name_, valid_nodes, input_names, output_names); - // set inputs and outputs alias - if (config_.input_aliases.size() == input_names.size()) { - for (size_t i = 0; i < input_names.size(); i++) { - graph->FindTensor(input_names[i])->alias = config_.input_aliases[i]; - } - } else { - for (size_t i = 0; i < input_names.size(); i++) { - graph->FindTensor(input_names[i])->alias = graph->FindProducer(input_names[i])->name; - } - } - if (config_.output_aliases.size() == output_names.size()) { - for (size_t i = 0; i < output_names.size(); i++) { - graph->FindTensor(output_names[i])->alias = config_.output_aliases[i]; - } - } else { - for (size_t i = 0; i < output_names.size(); i++) { - const auto& output = graph->FindTensor(output_names[i]); - if (output->alias.size() > 0) { - continue; - } - const auto& producer = graph->FindProducer(output_names[i]); - output->alias = producer->outputs.size() == 1 - ? producer->name - : StringUtils::Replace(output_names[i], ":", "_"); - } - } - return graph; -} - -MSCJoint RelayGraphBuilder::AddNode(const Expr& expr, const String& name) { - const auto& node_name = name.size() > 0 ? 
name : SpanUtils::GetAttr(expr->span, msc_attr::kName); - const auto& shared_ref = SpanUtils::GetAttr(expr->span, msc_attr::kSharedRef); - - // Get optype - String optype; - if (expr->IsInstance()) { - optype = "input"; - } else if (expr->IsInstance()) { - optype = "constant"; - } else if (expr->IsInstance()) { - optype = "get_item"; - } else if (expr->IsInstance()) { - optype = "tuple"; - } else if (const auto* call_node = expr.as()) { - if (const auto* op_node = call_node->op.as()) { - optype = StringUtils::Replace(op_node->name, "relay.", ""); - } else { - optype = "unknown_op"; - } - } else if (const auto* f_node = expr.as()) { - const auto& name_opt = f_node->GetAttr(relay::attr::kComposite); - ICHECK(name_opt.defined()) << "Unexpected func without composite"; - optype = name_opt.value(); - } else { - optype = "unknown_expr"; - } - - // Extract attributes - Map attrs; - if (const auto* call_node = expr.as()) { - if (call_node->attrs.defined()) { - AttrGetter getter(&attrs); - const_cast(call_node->attrs.get())->VisitAttrs(&getter); - } - } else if (expr->IsInstance()) { - attrs = RelayFuncAttrGetter().GetAttrs(expr); - } else if (const auto* const_node = expr.as()) { - if (const_node->is_scalar()) { - attrs.Set("scalar", GetScalarStr(const_node->data, config_.float_precision)); - } - } else if (const auto* get_node = expr.as()) { - attrs.Set("index", std::to_string(get_node->index)); - } - - // Get scope - Array scope; - if (optype != "input" && optype != "constant") { - scope.push_back("block"); - } - - // Build inputs and weights - Array input_names; - Map node_weights; - if (const auto* call_node = expr.as()) { - const auto& input_types = ExprUtils::GetInputTypes(optype, call_node->args.size(), false); - for (size_t i = 0; i < call_node->args.size(); i++) { - const auto& arg = call_node->args[i]; - ICHECK(expr_tensor_map_.count(arg)) << "Missing argument " << arg; - if (input_types[i] != "input" && arg->IsInstance()) { - const auto& t_name = expr_tensor_map_[arg][0]; - const auto& weight_name = SpanUtils::GetAttr(arg->span, msc_attr::kName); - const auto& pair = tensor_input_map_[t_name]; - const auto& producer = Downcast(pair.first); - if (!weights_.count(weight_name)) { - const auto& ref = producer->OutputAt(pair.second); - MSCTensor weight; - if (input_types[i] == "bias") { - weight = MSCTensor(weight_name, ref->dtype, "O", Array{ref->GetSize()}); - } else { - weight = MSCTensor(weight_name, ref->dtype, ref->layout.name(), ref->shape); - } - weights_.Set(weight_name, weight); - } - if (producer->HasAttr("scalar")) { - attrs.Set(input_types[i], producer->GetTypeAttr("scalar")); - } - node_weights.Set(input_types[i], weights_[weight_name]); - } else { - for (const auto& in_name : expr_tensor_map_[arg]) { - input_names.push_back(in_name); - } - } - } - } else if (const auto* f_node = expr.as()) { - for (const auto& p : f_node->params) { - for (const auto& in_name : expr_tensor_map_[p]) { - input_names.push_back(in_name); - } - } - ICHECK(HasFuncScope()) << "Function without func scope " << relay::PrettyPrint(expr); - const auto& weight_names = func_scopes_.top().GetFuncWeights(); - const auto& input_types = - ExprUtils::GetInputTypes(optype, f_node->params.size() + weight_names.size(), false); - for (size_t i = 0; i < weight_names.size(); i++) { - const auto& pair = tensor_input_map_[weight_names[i]]; - const auto& producer = Downcast(pair.first); - if (!weights_.count(producer->name)) { - const auto& ref = producer->OutputAt(pair.second); - const auto& weight = 
MSCTensor(producer->name, ref->dtype, ref->layout.name(), ref->shape); - weights_.Set(producer->name, weight); - } - if (producer->HasAttr("scalar")) { - attrs.Set(input_types[i], producer->GetTypeAttr("scalar")); - } - node_weights.Set(input_types[i + f_node->params.size()], weights_[producer->name]); - } - } else if (const auto* tuple_node = expr.as()) { - for (const auto& f : tuple_node->fields) { - ICHECK(expr_tensor_map_.count(f)) << "Can not find tuple field " << f; - for (const auto& in_name : expr_tensor_map_[f]) { - input_names.push_back(in_name); - } - } - } else if (const auto* getitem_node = expr.as()) { - ICHECK(expr_tensor_map_.count(getitem_node->tuple)) - << "Can not find tuple " << getitem_node->tuple; - input_names = expr_tensor_map_[getitem_node->tuple]; - } else if (optype == "constant") { - Type checked_type = expr->checked_type_; - ICHECK(checked_type.defined() && checked_type->IsInstance()) - << "Constant checked_type is not defined"; - const auto& t_info = Downcast(checked_type); - const auto& layout = SpanUtils::GetAttr(expr->span, msc_attr::kLayout); - const auto& weight = - MSCTensor(node_name, t_info->dtype, layout, ArrayUtils::Cast(t_info->shape)); - node_weights.Set("const", weight); - } - std::vector> inputs; - for (const auto& i : input_names) { - inputs.push_back(tensor_input_map_[i]); - } - - // Build outputs - Array outputs; - const auto& layout = SpanUtils::GetAttr(expr->span, msc_attr::kLayout); - Type checked_type = expr->checked_type_; - if (checked_type.defined() && checked_type->IsInstance()) { - checked_type = Downcast(checked_type)->ret_type; - } - if (checked_type.defined()) { - if (const auto* t_info = checked_type.as()) { - const auto& shape = ArrayUtils::Cast(t_info->shape); - const auto& output = - MSCTensor(node_name + ":" + std::to_string(0), t_info->dtype, layout, shape); - outputs.push_back(output); - } else if (const auto* tuple_info = checked_type.as()) { - Array layouts = StringUtils::Split(layout, ","); - if (layouts.size() == 0) { - layouts = Array(tuple_info->fields.size(), ""); - } - ICHECK_EQ(layouts.size(), tuple_info->fields.size()) - << "Layout " << layout << " msimatch with fileds size " << tuple_info->fields.size(); - size_t field_size = tuple_info->fields.size(); - if (optype == "nn.batch_norm") { - field_size = 1; - } - for (size_t i = 0; i < field_size; i++) { - const auto& t_info = Downcast(tuple_info->fields[i]); - const auto& shape = ArrayUtils::Cast(t_info->shape); - const auto& output = - MSCTensor(node_name + ":" + std::to_string(i), t_info->dtype, layouts[i], shape); - outputs.push_back(output); - } - } else { - LOG(FATAL) << "Unexpected checked_type " << checked_type; - } - } - - // Build node - const auto& node = MSCJoint(nodes_.size(), node_name, shared_ref, optype, attrs, scope, inputs, - outputs, node_weights); - Array output_names; - for (size_t i = 0; i < outputs.size(); i++) { - output_names.push_back(outputs[i]->name); - tensor_input_map_[outputs[i]->name] = std::make_pair(node, i); - } - nodes_.push_back(node); - expr_tensor_map_.Set(expr, output_names); - return node; -} - -void RelayGraphBuilder::VisitExpr_(const relay::ConstantNode* op) { - const auto& node = AddNode(GetRef(op)); - if (HasFuncScope()) { - func_scopes_.top().AddFuncWeight(node->OutputAt(0)->name); - } -} - -void RelayGraphBuilder::VisitExpr_(const relay::FunctionNode* op) { - const auto& name_opt = op->GetAttr(relay::attr::kComposite); - if (name_opt.defined()) { - StartFuncScope(SpanUtils::GetAttr(op->span, msc_attr::kName)); - } - 
RelaxExprVisitor::VisitExpr_(op); - if (HasFuncScope()) { - AddNode(GetRef(op)); - EndFuncScope(); - } -} - -void RelayGraphBuilder::VisitExpr_(const relay::CallNode* op) { - if (const auto* f_node = op->op.as()) { - const auto& name_opt = f_node->GetAttr(relay::attr::kComposite); - if (name_opt.defined()) { - for (size_t i = 0; i < op->args.size(); i++) { - if (!expr_tensor_map_.count(op->args[i])) { - RelaxExprVisitor::VisitExpr(op->args[i]); - } - ICHECK(expr_tensor_map_.count(op->args[i])) - << "Can not find argument " << relay::PrettyPrint(op->args[i]); - expr_tensor_map_.Set(f_node->params[i], expr_tensor_map_[op->args[i]]); - } - } - } - RelaxExprVisitor::VisitExpr_(op); - if (!HasFuncScope() && op->op->IsInstance()) { - try { - AddNode(GetRef(op)); - } catch (runtime::InternalError& err) { - LOG(WARNING) << "Failed to add node from " << relay::PrettyPrint(GetRef(op)) - << " : " << err.message(); - throw err; - } - } - if (op->op->IsInstance() && expr_tensor_map_.count(op->op)) { - expr_tensor_map_.Set(GetRef(op), expr_tensor_map_[op->op]); - } -} - -void RelayGraphBuilder::VisitExpr_(const relay::TupleNode* val) { - RelaxExprVisitor::VisitExpr_(val); - AddNode(GetRef(val)); -} - -void RelayGraphBuilder::VisitExpr_(const relay::TupleGetItemNode* val) { - RelaxExprVisitor::VisitExpr_(val); - AddNode(GetRef(val)); -} - -void RelayGraphBuilder::StartFuncScope(const String& name) { - RelayFuncScope func_scope = RelayFuncScope(name); - func_scopes_.push(func_scope); -} -void RelayGraphBuilder::EndFuncScope() { - ICHECK(HasFuncScope()) << "No FuncScope found"; - func_scopes_.pop(); -} - -bool RelayGraphBuilder::HasFuncScope() { return func_scopes_.size() > 0; } - -Map RelayWeightsExtractor::GetWeights(const relay::Function& func) { - VisitExpr(func); - return weights_; -} - -void RelayWeightsExtractor::VisitExpr_(const relay::ConstantNode* op) { - const auto& name = SpanUtils::GetAttr(op->span, msc_attr::kName); - const auto& layout = SpanUtils::GetAttr(op->span, msc_attr::kLayout); - const auto& t_info = op->tensor_type(); - const auto& shape = ArrayUtils::Cast(t_info->shape); - const auto& weight = MSCTensor(name, t_info->dtype, layout, shape); - weights_.Set(weight, op->data); -} - TVM_REGISTER_GLOBAL("msc.core.BuildFromRelax") - .set_body_typed([](const IRModule& relax_module, const String& entry_name, + .set_body_typed([](const IRModule& module, const String& entry_name, const String& options) -> MSCGraph { - auto builder = RelaxGraphBuilder(relax_module, entry_name, options); + auto builder = GraphBuilder(module, entry_name, options); const auto& func_name = builder.config().byoc_entry.size() > 0 ? 
String(builder.config().byoc_entry) : entry_name; - const auto& func = Downcast(relax_module->Lookup(func_name)); + const auto& func = Downcast(module->Lookup(func_name)); return builder.Build(func); }); TVM_REGISTER_GLOBAL("msc.core.GetRelaxWeights") - .set_body_typed([](const IRModule& relax_module, - const String& entry_name) -> Map { - const auto& func = Downcast(relax_module->Lookup(entry_name)); - return RelaxWeightsExtractor(relax_module).GetWeights(func); - }); - -TVM_REGISTER_GLOBAL("msc.core.BuildFromRelay") - .set_body_typed([](const IRModule& relay_module, const String& entry_name, - const String& options) -> MSCGraph { - const auto& func = Downcast(relay_module->Lookup(entry_name)); - return RelayGraphBuilder(relay_module, entry_name, options).Build(func); - }); - -TVM_REGISTER_GLOBAL("msc.core.GetRelayWeights") - .set_body_typed([](const IRModule& relay_module, + .set_body_typed([](const IRModule& module, const String& entry_name) -> Map { - const auto& func = Downcast(relay_module->Lookup(entry_name)); - return RelayWeightsExtractor().GetWeights(func); + const auto& func = Downcast(module->Lookup(entry_name)); + return WeightsExtractor(module).GetWeights(func); }); } // namespace msc diff --git a/src/contrib/msc/core/ir/graph_builder.h b/src/contrib/msc/core/ir/graph_builder.h index 9fd855455c1e..bd176dd05101 100644 --- a/src/contrib/msc/core/ir/graph_builder.h +++ b/src/contrib/msc/core/ir/graph_builder.h @@ -27,8 +27,6 @@ #include #include #include -#include -#include #include #include @@ -48,10 +46,9 @@ namespace tvm { namespace contrib { namespace msc { -using Expr = tvm::RelaxExpr; -using RelaxExprVisitor = tvm::relax::ExprVisitor; -using RelaxExprVisitor = tvm::relay::ExprVisitor; +using namespace tvm::relax; +using Expr = tvm::RelaxExpr; using tvm::runtime::NDArray; /*! @@ -149,7 +146,7 @@ class AttrGetter : public AttrVisitor { Map* attrs_; }; -class RelaxFuncAttrGetter : public RelaxExprVisitor { +class FuncAttrGetter : public ExprVisitor { public: /*! \brief Get the attributes as Map*/ Map GetAttrs(const Expr& expr) { @@ -157,15 +154,15 @@ class RelaxFuncAttrGetter : public RelaxExprVisitor { return attrs_; } - void VisitExpr_(const relax::CallNode* op) final; + void VisitExpr_(const CallNode* op) final; - void VisitExpr_(const relax::TupleGetItemNode* op) final; + void VisitExpr_(const TupleGetItemNode* op) final; private: Map attrs_; }; -class RelaxFuncValueGetter : public RelaxExprVisitor { +class FuncValueGetter : public ExprVisitor { public: /*! \brief Get the attributes from prim value as Map*/ Array GetValues(const Expr& expr) { @@ -173,19 +170,19 @@ class RelaxFuncValueGetter : public RelaxExprVisitor { return values_; } - void VisitExpr_(const relax::CallNode* op) final; + void VisitExpr_(const CallNode* op) final; private: Array values_; }; -class RelaxFuncParamsFinder : public RelaxExprVisitor { +class FuncParamsFinder : public ExprVisitor { public: /*! - * \brief The constructor of RelaxFuncParamsFinder + * \brief The constructor of FuncParamsFinder * \param ref_module the reference module. 
*/ - explicit RelaxFuncParamsFinder(const IRModule& ref_module) : RelaxExprVisitor() { + explicit FuncParamsFinder(const IRModule& ref_module) : ExprVisitor() { ref_module_ = ref_module; } @@ -195,25 +192,23 @@ class RelaxFuncParamsFinder : public RelaxExprVisitor { return params_; } - void VisitBinding_(const relax::VarBindingNode* binding, const relax::FunctionNode* val) final; + void VisitBinding_(const VarBindingNode* binding, const FunctionNode* val) final; - void VisitExpr_(const relax::CallNode* op) final; + void VisitExpr_(const CallNode* op) final; private: IRModule ref_module_; Map params_; - Map local_funcs_; + Map local_funcs_; }; -class RelaxLayoutsFinder : public RelaxExprVisitor { +class LayoutsFinder : public ExprVisitor { public: /*! - * \brief The constructor of RelaxLayoutsFinder + * \brief The constructor of LayoutsFinder * \param ref_module the reference module. */ - explicit RelaxLayoutsFinder(const IRModule& ref_module) : RelaxExprVisitor() { - ref_module_ = ref_module; - } + explicit LayoutsFinder(const IRModule& ref_module) : ExprVisitor() { ref_module_ = ref_module; } /*! \brief Find the layouts form attrs*/ Map FindLayouts(const Expr& expr) { @@ -221,27 +216,27 @@ class RelaxLayoutsFinder : public RelaxExprVisitor { return layouts_; } - void VisitBinding_(const relax::VarBindingNode* binding, const relax::FunctionNode* val) final; + void VisitBinding_(const VarBindingNode* binding, const FunctionNode* val) final; - void VisitExpr_(const relax::CallNode* op) final; + void VisitExpr_(const CallNode* op) final; private: IRModule ref_module_; Map layouts_; - Map local_funcs_; + Map local_funcs_; }; -class RelaxGraphBuilder : public RelaxExprVisitor { +class GraphBuilder : public ExprVisitor { public: /*! - * \brief The constructor of RelaxGraphBuilder + * \brief The constructor of GraphBuilder * \param ref_module the reference module. * \param name the name of the graph. * \param options the options of build the graph. */ - explicit RelaxGraphBuilder(const IRModule& ref_module, const String& name, - const std::string& options = "") - : RelaxExprVisitor() { + explicit GraphBuilder(const IRModule& ref_module, const String& name, + const std::string& options = "") + : ExprVisitor() { ref_module_ = ref_module; if (options.size() > 0) { std::istringstream is(options); @@ -250,13 +245,13 @@ class RelaxGraphBuilder : public RelaxExprVisitor { } name_ = config_.graph_name.size() > 0 ? String(config_.graph_name) : name; if (config_.byoc_entry.size() > 0) { - func_params_ = RelaxFuncParamsFinder(ref_module).FindParams(ref_module->Lookup(name)); + func_params_ = FuncParamsFinder(ref_module).FindParams(ref_module->Lookup(name)); } - layouts_ = RelaxLayoutsFinder(ref_module).FindLayouts(ref_module->Lookup(name)); + layouts_ = LayoutsFinder(ref_module).FindLayouts(ref_module->Lookup(name)); } /*! \brief Build MSCGraph from relax function*/ - const MSCGraph Build(const relax::Function& func); + const MSCGraph Build(const Function& func); /*! 
\brief Get the config of builder */ const MSCRBuildConfig config() { return config_; } @@ -272,35 +267,34 @@ class RelaxGraphBuilder : public RelaxExprVisitor { const Array& parents = Array(), const Map& attrs = Map()); - void VisitBindingBlock(const relax::BindingBlock& block) final; + void VisitBindingBlock(const BindingBlock& block) final; - void VisitExpr_(const relax::ConstantNode* op) final; + void VisitExpr_(const ConstantNode* op) final; - void VisitBinding_(const relax::VarBindingNode* binding, const relax::ConstantNode* val) final; + void VisitBinding_(const VarBindingNode* binding, const ConstantNode* val) final; - void VisitBinding_(const relax::VarBindingNode* binding, const relax::ShapeExprNode* val) final; + void VisitBinding_(const VarBindingNode* binding, const ShapeExprNode* val) final; - void VisitBinding_(const relax::VarBindingNode* binding, const relax::CallNode* call_node) final; + void VisitBinding_(const VarBindingNode* binding, const CallNode* call_node) final; - void VisitBinding_(const relax::VarBindingNode* binding, const relax::TupleNode* val) final; + void VisitBinding_(const VarBindingNode* binding, const TupleNode* val) final; - void VisitBinding_(const relax::VarBindingNode* binding, - const relax::TupleGetItemNode* val) final; + void VisitBinding_(const VarBindingNode* binding, const TupleGetItemNode* val) final; - void VisitBinding_(const relax::VarBindingNode* binding, const relax::VarNode* val) final; + void VisitBinding_(const VarBindingNode* binding, const VarNode* val) final; - void VisitBinding_(const relax::VarBindingNode* binding, const relax::DataflowVarNode* val) final; + void VisitBinding_(const VarBindingNode* binding, const DataflowVarNode* val) final; - void VisitBinding_(const relax::VarBindingNode* binding, const relax::FunctionNode* val) final; + void VisitBinding_(const VarBindingNode* binding, const FunctionNode* val) final; void VisitPrimExpr(const PrimExpr& prim) final; private: /*! \brief Get the node_name, optype, layout for func*/ - const std::tuple ParseFunc(const relax::Function& func); + const std::tuple ParseFunc(const Function& func); /*! \brief Get the plugin inputs*/ - Array GetPluginInputs(const relax::Expr& expr); + Array GetPluginInputs(const Expr& expr); String name_; IRModule ref_module_; @@ -316,145 +310,38 @@ class RelaxGraphBuilder : public RelaxExprVisitor { std::set setted_blocks_; Array block_stack_; // BYOC maps - Map target_funcs_; + Map target_funcs_; Map func_params_; // prims Array prims_; Map prim_map_; }; -class RelaxWeightsExtractor : public RelaxExprVisitor { +class WeightsExtractor : public ExprVisitor { public: /*! - * \brief The constructor of RelaxGraphBuilder + * \brief The constructor of GraphBuilder * \param ref_module the reference module. * \param name the name of the graph. * \param options the options of build the graph. */ - explicit RelaxWeightsExtractor(const IRModule& ref_module) : RelaxExprVisitor() { + explicit WeightsExtractor(const IRModule& ref_module) : ExprVisitor() { ref_module_ = ref_module; } /*! 
\brief Visit the constant and save weights */ - Map GetWeights(const relax::Function& func); + Map GetWeights(const Function& func); - void VisitExpr_(const relax::ConstantNode* op) final; + void VisitExpr_(const ConstantNode* op) final; - void VisitExpr_(const relax::CallNode* op) final; + void VisitExpr_(const CallNode* op) final; private: Map weights_; - Map local_funcs_; + Map local_funcs_; IRModule ref_module_; }; -class RelayFuncAttrGetter : public RelaxExprVisitor { - public: - /*! \brief Get the attributes as Map*/ - Map GetAttrs(const Expr& expr) { - RelayFuncAttrGetter::VisitExpr(expr); - return attrs_; - } - - void VisitExpr_(const relay::CallNode* op) final; - - private: - Map attrs_; -}; - -/*! - * \brief A Scope for recording func - */ -class RelayFuncScope { - public: - /*! \brief The constructor */ - explicit RelayFuncScope(const String& name) : name_(name) {} - - /*! \brief Add a weight */ - void AddFuncWeight(const String& weight) { func_weights_.push_back(weight); } - - /*! \brief Get weights */ - const Array GetFuncWeights() { return func_weights_; } - - private: - String name_; - Array func_weights_; -}; - -class RelayGraphBuilder : public RelaxExprVisitor { - public: - /*! - * \brief The constructor of RelayGraphBuilder - * \param ref_module the reference module. - * \param name the name of the graph. - * \param options the options of build the graph. - */ - explicit RelayGraphBuilder(const IRModule& ref_module, const String& name, - const std::string& options = "") - : RelaxExprVisitor() { - ref_module_ = ref_module; - if (options.size() > 0) { - std::istringstream is(options); - dmlc::JSONReader reader(&is); - reader.Read(&config_); - } - while (!func_scopes_.empty()) { - func_scopes_.pop(); - } - name_ = config_.graph_name.size() > 0 ? String(config_.graph_name) : name; - } - - /*! \brief Build MSCGraph from relax function*/ - MSCGraph Build(const relay::Function& func); - - /*! \brief Get the config of builder */ - const MSCRBuildConfig config() { return config_; } - - /*! \brief Create and add MSCJoint from expr*/ - MSCJoint AddNode(const Expr& expr, const String& name = ""); - - void VisitExpr_(const relay::ConstantNode* op) final; - - void VisitExpr_(const relay::FunctionNode* op) final; - - void VisitExpr_(const relay::CallNode* op) final; - - void VisitExpr_(const relay::TupleNode* val) final; - - void VisitExpr_(const relay::TupleGetItemNode* val) final; - - protected: - /*! \brief Start a func scope */ - void StartFuncScope(const String& scope); - - /*! \brief End a func scope */ - void EndFuncScope(); - - /*! \brief Check if has func scopes left */ - bool HasFuncScope(); - - private: - String name_; - MSCRBuildConfig config_; - IRModule ref_module_; - Array nodes_; - Map weights_; - Map> expr_tensor_map_; - std::unordered_map> tensor_input_map_; - std::stack func_scopes_; -}; - -class RelayWeightsExtractor : public RelaxExprVisitor { - public: - /*! \brief Visit the constant and save weights*/ - Map GetWeights(const relay::Function& func); - - void VisitExpr_(const relay::ConstantNode* op) final; - - private: - Map weights_; -}; - } // namespace msc } // namespace contrib } // namespace tvm diff --git a/src/contrib/msc/core/transform/bind_named_params.cc b/src/contrib/msc/core/transform/bind_named_params.cc index 6256fae05f83..da5a60d88e75 100644 --- a/src/contrib/msc/core/transform/bind_named_params.cc +++ b/src/contrib/msc/core/transform/bind_named_params.cc @@ -17,7 +17,6 @@ * under the License. 
*/ -#include #include #include #include diff --git a/src/contrib/msc/core/transform/set_expr_name.cc b/src/contrib/msc/core/transform/set_expr_name.cc index 0418513c145a..ede6068d541c 100644 --- a/src/contrib/msc/core/transform/set_expr_name.cc +++ b/src/contrib/msc/core/transform/set_expr_name.cc @@ -25,9 +25,6 @@ #include #include #include -#include -#include -#include #include "../utils.h" @@ -333,224 +330,4 @@ TVM_REGISTER_GLOBAL("relax.transform.SetRelaxExprName").set_body_typed(SetRelaxE } // namespace transform } // namespace relax - -namespace relay { - -/*! - * \brief Name setter for Relay - */ -class RelaxExprNameSetter : public ExprVisitor { - public: - explicit RelaxExprNameSetter(const IRModule& ref_module) : ref_module_(ref_module) {} - - void VisitExpr_(const ConstantNode* op) final { - ExprVisitor::VisitExpr_(op); - const String& unique_name = GetUniqueName(GetRef(op), "const"); - if (unique_name != SpanUtils::GetAttr(op->span, msc_attr::kName)) { - op->span = SpanUtils::SetAttr(op->span, msc_attr::kName, unique_name); - } - } - - void VisitExpr_(const TupleNode* op) final { - ExprVisitor::VisitExpr_(op); - const String& unique_name = GetUniqueName(GetRef(op), "tuple"); - if (unique_name != SpanUtils::GetAttr(op->span, msc_attr::kName)) { - op->span = SpanUtils::SetAttr(op->span, msc_attr::kName, unique_name); - } - } - - void VisitExpr_(const TupleGetItemNode* op) final { - ExprVisitor::VisitExpr_(op); - const String& tuple_name = SpanUtils::GetAttr(op->tuple->span, msc_attr::kName); - const String& unique_name = tuple_name + "." + std::to_string(op->index); - if (unique_name != SpanUtils::GetAttr(op->span, msc_attr::kName)) { - op->span = SpanUtils::SetAttr(op->span, msc_attr::kName, unique_name); - } - } - - void VisitExpr_(const FunctionNode* op) final { - ExprVisitor::VisitExpr_(op); - const auto& name_opt = op->GetAttr(attr::kComposite); - const String& name_hint = name_opt.defined() ? 
name_opt.value() : "func"; - const String& unique_name = GetUniqueName(GetRef(op), name_hint); - if (unique_name != SpanUtils::GetAttr(op->span, msc_attr::kName)) { - op->span = SpanUtils::SetAttr(op->span, msc_attr::kName, unique_name); - } - } - - void VisitExpr_(const CallNode* op) final { - ExprVisitor::VisitExpr_(op); - String name_hint, optype; - if (const auto* op_node = op->op.as()) { - const std::string& op_name = op_node->name; - int rpos = op_name.rfind("."); - name_hint = op_name.substr(rpos + 1); - optype = StringUtils::Replace(op_node->name, "relay.", ""); - } else if (const auto* v_node = op->op.as()) { - const auto& func = Downcast(ref_module_->Lookup(v_node->name_hint)); - const auto& name_opt = func->GetAttr(attr::kComposite); - if (name_opt.defined()) { - optype = name_opt.value(); - name_hint = optype; - ExprVisitor::VisitExpr(func); - } else { - optype = "extern_func"; - name_hint = v_node->name_hint; - } - } - if (name_hint.size() > 0) { - // set name - const String& unique_name = GetUniqueName(GetRef(op), name_hint); - if (unique_name != SpanUtils::GetAttr(op->span, msc_attr::kName)) { - op->span = SpanUtils::SetAttr(op->span, msc_attr::kName, unique_name); - } - // set constant consumer && shared_ref - Array input_types; - try { - input_types = ExprUtils::GetInputTypes(optype, op->args.size(), false); - } catch (runtime::InternalError& err) { - LOG(WARNING) << "Failed to GetInputTypes for " << GetRef(op) << " : " - << err.message(); - throw err; - } - for (size_t i = 0; i < input_types.size(); i++) { - if (input_types[i] == "input") { - continue; - } - if (const auto* c_node = op->args[i].as()) { - const String& const_name = SpanUtils::GetAttr(c_node->span, msc_attr::kName); - if (constant_consumers_.count(const_name)) { - op->span = - SpanUtils::SetAttr(op->span, msc_attr::kSharedRef, constant_consumers_[const_name]); - } else { - constant_consumers_.Set(const_name, unique_name); - } - } - } - } - } - - private: - const String GetUniqueName(const Expr& expr, const String& name_hint) { - String expr_name = SpanUtils::GetAttr(expr->span, msc_attr::kName); - if (expr_name.size() == 0) { - expr_name = name_hint; - } - if (!setted_names_.count(expr_name)) { - setted_names_.Set(expr_name, expr); - return expr_name; - } - if (setted_names_[expr_name] == expr) { - return expr_name; - } - int cnt = 1; - while (setted_names_.count(expr_name + "_" + std::to_string(cnt)) && - setted_names_[expr_name + "_" + std::to_string(cnt)] != expr) { - cnt++; - } - expr_name = expr_name + "_" + std::to_string(cnt); - if (!setted_names_.count(expr_name)) { - setted_names_.Set(expr_name, expr); - } - return expr_name; - } - - Map setted_names_; - Map constant_consumers_; - IRModule ref_module_; -}; // class ExprNameSetter - -void SetRelaxExprName(const IRModule& ref_module, const Expr& e) { - RelaxExprNameSetter(ref_module).VisitExpr(e); -} - -/*! 
- * \brief Name binder for Relay - */ -class RelaxExprNameBinder : public ExprVisitor { - public: - explicit RelaxExprNameBinder(const String& name_key, const String& seperator) - : name_key_(name_key), seperator_(seperator) {} - - void VisitExpr_(const ConstantNode* op) final { - if (op->span.defined()) { - BindName(GetRef(op)); - } - } - - void VisitExpr_(const CallNode* op) final { - if (op->span.defined()) { - BindName(GetRef(op)); - } - ExprVisitor::VisitExpr_(op); - } - - private: - void BindName(const Expr& expr) { - const auto& name = expr->span->source_name->name; - String valid_name; - if (name_key_.size() == 0) { - valid_name = name; - expr->span = Span(SourceName::Get(""), expr->span->line, expr->span->end_line, - expr->span->column, expr->span->end_column); - } else { - String right = std::get<1>(StringUtils::SplitOnce(name, name_key_)); - if (right.size() > 0) { - valid_name = std::get<0>(StringUtils::SplitOnce(name, seperator_)); - if (valid_name.size() > 0) { - const auto& new_source = StringUtils::Replace(name, name_key_ + valid_name, ""); - expr->span = Span(SourceName::Get(new_source), expr->span->line, expr->span->end_line, - expr->span->column, expr->span->end_column); - } - } - } - if (valid_name.size() > 0) { - if (setted_names_.count(valid_name)) { - int cnt = 1; - while (setted_names_.count(valid_name + "_" + std::to_string(cnt)) && - setted_names_[valid_name + "_" + std::to_string(cnt)] != expr) { - cnt++; - } - valid_name = valid_name + "_" + std::to_string(cnt); - } - setted_names_.Set(valid_name, expr); - expr->span = SpanUtils::SetAttr(expr->span, msc_attr::kName, valid_name); - } - } - - Map setted_names_; - String name_key_; - String seperator_; -}; // class ExprNameBinder - -void BindRelaxExprName(const Expr& e, const String& name_key, const String& seperator) { - RelaxExprNameBinder(name_key, seperator).VisitExpr(e); -} - -namespace transform { - -Pass SetRelaxExprName(const String& entry_name) { - runtime::TypedPackedFunc pass_func = [=](IRModule m, - PassContext pc) { - relay::SetRelaxExprName(m, m->Lookup(entry_name)); - return m; - }; - return CreateModulePass(pass_func, 0, "SetRelaxExprName", {}); -} - -TVM_REGISTER_GLOBAL("relay._transform.SetRelaxExprName").set_body_typed(SetRelaxExprName); - -Pass BindRelaxExprName(const String& name_key, const String& seperator, const String& entry_name) { - runtime::TypedPackedFunc pass_func = [=](IRModule m, - PassContext pc) { - relay::BindRelaxExprName(m->Lookup(entry_name), name_key, seperator); - return m; - }; - return CreateModulePass(pass_func, 0, "BindRelaxExprName", {}); -} - -TVM_REGISTER_GLOBAL("relay._transform.BindRelaxExprName").set_body_typed(BindRelaxExprName); - -} // namespace transform -} // namespace relay } // namespace tvm diff --git a/src/contrib/msc/core/utils.cc b/src/contrib/msc/core/utils.cc index 1e846b0b3a61..c1348c4016a8 100644 --- a/src/contrib/msc/core/utils.cc +++ b/src/contrib/msc/core/utils.cc @@ -29,6 +29,8 @@ namespace tvm { namespace contrib { namespace msc { +using namespace tvm::relax; + size_t CommonUtils::GetIndex(int index, size_t max_size) { size_t v_index; if (index < 0) { @@ -278,9 +280,9 @@ const String StringUtils::ToString(const runtime::ObjectRef& obj) { obj_string = obj_string + ","; } } - } else if (const auto* n = obj.as()) { + } else if (const auto* n = obj.as()) { obj_string = ToString(n->value); - } else if (const auto* n = obj.as()) { + } else if (const auto* n = obj.as()) { obj_string = ToString(n->fields); } else { std::ostringstream obj_des; @@ -489,16 
+491,11 @@ const Array ExprUtils::GetInputTypes(const String& optype, size_t inputs return input_types; } -const Array ExprUtils::GetInputTypes(const RelaxCall& call) { +const Array ExprUtils::GetInputTypes(const Call& call) { const String& optype = StringUtils::Replace(Downcast(call->op)->name, "relax.", ""); return GetInputTypes(optype, call->args.size(), true); } -const Array ExprUtils::GetInputTypes(const RelayCall& call) { - const String& optype = StringUtils::Replace(Downcast(call->op)->name, "relay.", ""); - return GetInputTypes(optype, call->args.size(), false); -} - const String ExprUtils::GetSpanName(const Expr& expr, const String& suffix) { const auto& name = SpanUtils::GetAttr(expr->span, msc_attr::kName); if (suffix.size() > 0) { @@ -507,7 +504,7 @@ const String ExprUtils::GetSpanName(const Expr& expr, const String& suffix) { return name; } -const Array ExprUtils::GetShape(const relax::TensorStructInfo& sinfo, bool as_int) { +const Array ExprUtils::GetShape(const TensorStructInfo& sinfo, bool as_int) { const auto& shape_opt = sinfo->GetShape(); if (!shape_opt.defined()) { return Array(); @@ -523,11 +520,11 @@ const Array ExprUtils::GetShape(const relax::TensorStructInfo& sinfo, } const Array ExprUtils::GetShape(const Expr& expr, bool as_int) { - return GetShape(Downcast(relax::GetStructInfo(expr)), as_int); + return GetShape(Downcast(GetStructInfo(expr)), as_int); } const DataType ExprUtils::GetDataType(const Expr& expr) { - return Downcast(relax::GetStructInfo(expr))->dtype; + return Downcast(GetStructInfo(expr))->dtype; } TVM_REGISTER_GLOBAL("msc.core.SpanGetAttr").set_body_typed(SpanUtils::GetAttr); diff --git a/src/contrib/msc/core/utils.h b/src/contrib/msc/core/utils.h index 9bcaba2a271f..41566883036c 100644 --- a/src/contrib/msc/core/utils.h +++ b/src/contrib/msc/core/utils.h @@ -27,7 +27,6 @@ #include #include #include -#include #include #include @@ -37,9 +36,8 @@ namespace tvm { namespace contrib { namespace msc { +using namespace tvm::relax; using Expr = tvm::RelaxExpr; -using RelaxCall = tvm::relax::Call; -using RelayCall = tvm::relay::Call; namespace msc_attr { /*! \brief Mark the name for the expr. */ @@ -324,13 +322,7 @@ class ExprUtils { * \brief Get the input types of call. * \return The input types. */ - TVM_DLL static const Array GetInputTypes(const RelaxCall& call); - - /*! - * \brief Get the input types of call. - * \return The input types. - */ - TVM_DLL static const Array GetInputTypes(const RelayCall& call); + TVM_DLL static const Array GetInputTypes(const Call& call); /*! * \brief Get the scalar value of ndarray. @@ -375,16 +367,7 @@ class ExprUtils { * \return The scalar value. */ template - TVM_DLL static const T GetScalar(const relax::Constant& constant, size_t i = 0) { - return GetScalar(constant->data, i); - } - - /*! - * \brief Get the scalar value of relay constant. - * \return The scalar value. - */ - template - TVM_DLL static const T GetScalar(const relay::Constant& constant, size_t i = 0) { + TVM_DLL static const T GetScalar(const Constant& constant, size_t i = 0) { return GetScalar(constant->data, i); } @@ -398,8 +381,7 @@ class ExprUtils { * \brief Get shape of expr. * \return The shape. */ - TVM_DLL static const Array GetShape(const relax::TensorStructInfo& sinfo, - bool as_int = true); + TVM_DLL static const Array GetShape(const TensorStructInfo& sinfo, bool as_int = true); TVM_DLL static const Array GetShape(const Expr& expr, bool as_int = true); /*! 
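
Reviewer note (not part of the patch): with the relay-side registrations gone, the only graph-building globals left registered in graph_builder.cc are msc.core.BuildFromRelax and msc.core.GetRelaxWeights. The sketch below shows roughly how those two entry points can be driven from Python over PackedFunc. The toy relax module is an assumption made up for illustration, and the naming/fusion transforms that the MSC frontend normally applies to the module before building are elided here, so treat this as a shape-of-the-API sketch rather than a guaranteed end-to-end snippet.

    # Illustrative sketch only, not part of this patch.
    # ToyModule is a made-up example; real flows go through the MSC frontend,
    # which prepares the module (expr naming, fusion) before these calls.
    import tvm
    from tvm.script import ir as I, relax as R


    @I.ir_module
    class ToyModule:
        @R.function
        def main(x: R.Tensor((1, 16), "float32")) -> R.Tensor((1, 16), "float32"):
            with R.dataflow():
                y = R.nn.relu(x)
                R.output(y)
            return y


    # The two relax-only globals kept by this patch.
    build_from_relax = tvm.get_global_func("msc.core.BuildFromRelax")
    get_relax_weights = tvm.get_global_func("msc.core.GetRelaxWeights")

    # Third argument is a JSON-encoded MSCRBuildConfig; an empty string keeps
    # the defaults, while e.g. '{"graph_name": "toy_graph"}' would override
    # the graph name handled in GraphBuilder's constructor above.
    graph = build_from_relax(ToyModule, "main", "")
    weights = get_relax_weights(ToyModule, "main")
    print(graph)

The empty options string mirrors the `options.size() > 0` guard in the GraphBuilder constructor, which only feeds the string through dmlc::JSONReader when it is non-empty.
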
diff --git a/src/contrib/msc/framework/tensorrt/transform_tensorrt.cc b/src/contrib/msc/framework/tensorrt/transform_tensorrt.cc index 542e15d06c3c..94dfec7ea621 100644 --- a/src/contrib/msc/framework/tensorrt/transform_tensorrt.cc +++ b/src/contrib/msc/framework/tensorrt/transform_tensorrt.cc @@ -22,6 +22,7 @@ * \brief Pass for transform the function to tensorrt. */ +#include #include #include #include diff --git a/tests/python/contrib/test_msc/test_translate_tensorflow.py b/tests/python/contrib/test_msc/test_translate_tensorflow.py deleted file mode 100644 index 857cffbbd87a..000000000000 --- a/tests/python/contrib/test_msc/test_translate_tensorflow.py +++ /dev/null @@ -1,1691 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=deprecated-module - -"""Test translate from tensorflow.""" - -import pytest -from packaging import version as package_version -import numpy as np - -from tensorflow.python.framework import constant_op -from tensorflow.python.ops import nn_ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import variables - - -try: - import tensorflow.compat.v1 as tf - - tf.disable_v2_behavior() -except ImportError: - import tensorflow as tf - -import tvm -import tvm.testing -import tvm.relay.testing.tf as tf_testing -from tvm.contrib.msc.framework.tensorflow.frontend import translate -from tvm.contrib.msc.framework.tensorflow import codegen - - -# Only allow TF to run on half the GPU RAM to save the other half -# For TVM -gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.5) -gpu_sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) -gpu_sess.close() - - -def convert_to_list(x): - if not isinstance(x, list): - x = [x] - return x - - -def run_tf_graph(sess, input_data, input_node, output_node): - """Generic function to execute tensorflow""" - - input_data = convert_to_list(input_data) - input_node = convert_to_list(input_node) - output_node = convert_to_list(output_node) - - tensor = [sess.graph.get_tensor_by_name(output_name) for output_name in output_node] - - input_dict = {e: input_data[i] for i, e in enumerate(input_node)} - if len(input_node) == 1 and input_node[0] == "": - output_data = sess.run(tensor) - else: - output_data = sess.run(tensor, input_dict) - return output_data - - -def get_graph_def(in_data, in_name, out_name): - """Get tf.GraphDef for translate""" - - def name_without_num(name): - return name.split(":")[0] if ":" in name else name - - out_name = convert_to_list(out_name) - out_node = [name_without_num(name) for name in out_name] - in_data = convert_to_list(in_data) - in_name = convert_to_list(in_name) - - with tf.Session() as sess: - sess.run(variables.global_variables_initializer()) - final_graph_def = 
tf_testing.AddShapesToGraphDef(sess, out_node) - golden = run_tf_graph(sess, in_data, in_name, out_name) - return final_graph_def, golden - - -def verify_model(graph_def, golden, in_data, in_name, out_name, use_out_name=True): - """Generic function to generate and compare tensorflow and MSC-TFV1 output""" - - out_name = convert_to_list(out_name) - in_data = convert_to_list(in_data) - in_name = convert_to_list(in_name) - shape_dict = {i: d.shape for i, d in zip(in_name, in_data)} - graph, weights = translate.from_tensorflow(graph_def, shape_dict, out_name) - with tf.Graph().as_default(): - outputs = codegen.to_tensorflow(graph, weights) - with tf.Session() as sess: - sess.run(variables.global_variables_initializer()) - if not use_out_name: - out_name = [o.name for o in convert_to_list(outputs)] - result = run_tf_graph(sess, in_data, in_name, out_name) - - golden = convert_to_list(golden) - result = convert_to_list(result) - assert len(golden) == len(result), "golden {} mismatch with result {}".format( - len(golden), len(result) - ) - for gol_r, new_r in zip(golden, result): - if isinstance(gol_r, np.ndarray): - tvm.testing.assert_allclose(gol_r, new_r, atol=1e-5, rtol=1e-5) - else: - assert gol_r == new_r - - -def _test_pooling(input_shape, **kwargs): - """One iteration of pool operation with given shapes and attributes""" - - x = -np.arange(np.prod(input_shape), dtype=np.float32).reshape(input_shape) - 1 - - with tf.Graph().as_default(): - in_data = array_ops.placeholder(shape=input_shape, dtype="float32") - nn_ops.pool(in_data, **kwargs) - out_name = "max_pool:0" if kwargs["pooling_type"] == "MAX" else "avg_pool:0" - io_info = {"in_data": x, "in_name": "Placeholder:0", "out_name": out_name} - graph_def, golden = get_graph_def(**io_info) - verify_model(graph_def, golden, **io_info) - - -def test_pooling(): - """test tensorflow translator for pooling""" - - for pool_type in ["AVG", "MAX"]: - _test_pooling( - input_shape=[2, 9, 10, 2], - window_shape=[2, 1], - padding="SAME", - pooling_type=pool_type, - dilation_rate=[1, 1], - strides=[1, 1], - ) - - _test_pooling( - input_shape=[2, 9, 10, 2], - window_shape=[2, 1], - padding="VALID", - pooling_type=pool_type, - dilation_rate=[1, 1], - strides=[1, 1], - ) - - _test_pooling( - input_shape=[1, 2, 1], - window_shape=[1], - padding="VALID", - pooling_type=pool_type, - dilation_rate=[1], - ) - - # Explicit padding - if package_version.parse(tf.VERSION) >= package_version.parse("2.4.1"): - _test_pooling( - input_shape=[2, 9, 10, 2], - window_shape=[4, 4], - padding=[[0, 0], [0, 1], [2, 3], [0, 0]], - pooling_type="MAX", - dilation_rate=[1, 1], - strides=[1, 1], - ) - - -def _test_convolution( - opname, - tensor_in_sizes, - filter_in_sizes, - dilations, - strides, - padding, - data_format, -): - """One iteration of convolution with given shapes and attributes""" - total_size_1 = np.prod(tensor_in_sizes) - total_size_2 = np.prod(filter_in_sizes) - # Initializes the input tensor with array containing incrementing - # numbers from 1. 
- data_array = [f * 1.0 for f in range(1, total_size_1 + 1)] - filter_array = [f * 1.0 for f in range(1, total_size_2 + 1)] - - with tf.Graph().as_default(): - in_data = array_ops.placeholder(shape=tensor_in_sizes, dtype="float32") - in_filter = constant_op.constant(filter_array, shape=filter_in_sizes, dtype="float32") - if data_format == "NHWC": - strides = [1] + strides + [1] - dilations = [1] + dilations + [1] - else: - strides = [1, 1] + strides - dilations = [1, 1] + dilations - - if opname == "conv": - nn_ops.conv2d( - in_data, - in_filter, - strides=strides, - dilations=dilations, - padding=padding, - data_format=data_format, - ) - io_info = { - "in_data": np.reshape(data_array, tensor_in_sizes).astype("float32"), - "in_name": "Placeholder:0", - "out_name": "Conv2D:0", - } - graph_def, golden = get_graph_def(**io_info) - else: - nn_ops.depthwise_conv2d_native( - in_data, - in_filter, - strides=strides, - dilations=dilations, - padding=padding, - data_format=data_format, - ) - io_info = { - "in_data": np.reshape(data_array, tensor_in_sizes).astype("float32"), - "in_name": "Placeholder:0", - "out_name": "DepthwiseConv2dNative:0", - } - graph_def, golden = get_graph_def(**io_info) - verify_model(graph_def, golden, **io_info) - - -def test_convolution(): - """test tensorflow translator for convolution""" - - _test_convolution("conv", [4, 8, 8, 176], [1, 1, 176, 32], [1, 1], [1, 1], "SAME", "NHWC") - _test_convolution("conv", [4, 17, 17, 19], [3, 3, 19, 19], [1, 1], [2, 2], "VALID", "NHWC") - _test_convolution("depthwise", [4, 8, 8, 176], [1, 1, 176, 1], [1, 1], [1, 1], "SAME", "NHWC") - _test_convolution("depthwise", [4, 17, 17, 19], [3, 3, 19, 1], [1, 1], [2, 2], "VALID", "NHWC") - - # Explicit padding - if package_version.parse(tf.VERSION) >= package_version.parse("2.4.1"): - _test_convolution( - "conv", - [4, 8, 8, 16], - [1, 1, 16, 32], - [1, 1], - [1, 1], - [[0, 0], [2, 3], [0, 1], [0, 0]], - "NHWC", - ) - _test_convolution( - "depthwise", - [4, 8, 8, 16], - [1, 1, 16, 1], - [1, 1], - [1, 1], - [[0, 0], [2, 3], [0, 1], [0, 0]], - "NHWC", - ) - - -def _test_biasadd(tensor_in_sizes, data_format): - """One iteration of biasadd with given shapes and attributes""" - - total_size_1 = 1 - for s in tensor_in_sizes: - total_size_1 *= s - tensor_bias_sizes = [tensor_in_sizes[1]] if data_format == "NCHW" else [tensor_in_sizes[3]] - total_size_2 = tensor_bias_sizes[0] - # Initializes the input tensor with array containing incrementing - # numbers from 1. 
- data_array = [f * 1.0 for f in range(1, total_size_1 + 1)] - bias_array = [f * 1.0 for f in range(1, total_size_2 + 1)] - - with tf.Graph().as_default(): - in_data = array_ops.placeholder(shape=tensor_in_sizes, dtype="float32") - in_bias = constant_op.constant(bias_array, shape=tensor_bias_sizes, dtype="float32") - nn_ops.bias_add(in_data, in_bias, data_format=data_format) - io_info = { - "in_data": np.reshape(data_array, tensor_in_sizes).astype("float32"), - "in_name": "Placeholder:0", - "out_name": "BiasAdd:0", - } - graph_def, golden = get_graph_def(**io_info) - verify_model(graph_def, golden, **io_info) - - -def test_biasadd(): - """test tensorflow translator for bias_add""" - - _test_biasadd([4, 8, 8, 176], "NHWC") - - -def _test_where_with_broadcast(in_shape, cond_shape): - choice_list = list(np.arange(10).astype("float32")) - t_1 = np.random.choice(choice_list, size=cond_shape) - t_2 = np.random.choice(choice_list, size=cond_shape) - x = np.random.choice(choice_list, size=in_shape) - y = np.random.choice(choice_list, size=in_shape) - - with tf.Graph().as_default(): - in1 = tf.placeholder(shape=cond_shape, dtype="float32", name="in1") - in2 = tf.placeholder(shape=cond_shape, dtype="float32", name="in2") - condition = math_ops.less(in1, in2, name="less") - lhs = tf.placeholder(shape=in_shape, dtype="float32", name="x") - rhs = tf.placeholder(shape=in_shape, dtype="float32", name="y") - out = tf.where(condition, lhs, rhs) - io_info = { - "in_data": [t_1, t_2, x, y], - "in_name": ["in1:0", "in2:0", "x:0", "y:0"], - "out_name": out.name, - } - graph_def, golden = get_graph_def(**io_info) - verify_model(graph_def, golden, **io_info) - - -def test_where_with_broadcast(): - """test tensorflow translator for where""" - - _test_where_with_broadcast((5, 2), (5,)) - _test_where_with_broadcast((3, 2, 5), (3,)) - - -def _test_reshape(data, out_shape): - """One iteration of reshape operation with given data and out shape""" - - with tf.Graph().as_default(): - in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype) - array_ops.reshape(in_data, out_shape) - io_info = {"in_data": data, "in_name": "Placeholder:0", "out_name": "Reshape:0"} - graph_def, golden = get_graph_def(**io_info) - verify_model(graph_def, golden, **io_info) - - -def _test_reshape_with_call(): - """relay.expr.Call as shape""" - data = np.zeros((6, 4, 2)) - with tf.Graph().as_default(): - in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype) - out_shape = tf.constant([1, 2, 3], dtype="int32") - out_shape = tf.multiply(out_shape, 2) - array_ops.reshape(in_data, out_shape) - io_info = {"in_data": data, "in_name": "Placeholder:0", "out_name": "Reshape:0"} - graph_def, golden = get_graph_def(**io_info) - verify_model(graph_def, golden, **io_info) - - -def _test_reshape_like(data, shape_like): - """A special case for reshape.""" - - with tf.Graph().as_default(): - in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype) - in_shape_like = array_ops.placeholder(shape=shape_like.shape, dtype=data.dtype) - out_shape = array_ops.shape(in_shape_like) - array_ops.reshape(in_data, out_shape) - io_info = {"in_data": data, "in_name": "Placeholder:0", "out_name": "Reshape:0"} - graph_def, golden = get_graph_def(**io_info) - verify_model(graph_def, golden, **io_info) - - -def test_reshape(): - """test tensorflow translator for reshape""" - - _test_reshape(np.arange(6.0), [2, 3]) - _test_reshape(np.arange(6), [-1, 2]) - _test_reshape_with_call() - _test_reshape_like(np.zeros((3, 6)), np.zeros((9, 2))) - - -def 
_test_sigmoid(data): - """One iteration of sigmoid""" - - with tf.Graph().as_default(): - in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype) - _ = math_ops.sigmoid(in_data) - io_info = {"in_data": data, "in_name": "Placeholder:0", "out_name": "Sigmoid:0"} - graph_def, golden = get_graph_def(**io_info) - verify_model(graph_def, golden, **io_info) - - -def test_sigmoid(): - """test tensorflow translator for concat""" - - _test_sigmoid(np.random.uniform(size=(3, 4, 4, 3)).astype("float32")) - - -def _test_argx(func, data, **kwargs): - with tf.Graph().as_default(): - inp = array_ops.placeholder(shape=data.shape, dtype=data.dtype, name="c0") - func(inp, name="argx", **kwargs) - io_info = {"in_data": data, "in_name": "c0:0", "out_name": "argx:0"} - graph_def, golden = get_graph_def(**io_info) - verify_model(graph_def, golden, **io_info) - - -def test_argx(): - """test tensorflow translator for argmax/argmin""" - - data = np.random.uniform(size=(8, 4, 9)).astype("float32") - for output_type in [tf.int64, tf.int32]: - _test_argx(tf.argmax, data=data, axis=1, output_type=output_type) - _test_argx(tf.argmin, data=data, axis=1, output_type=output_type) - - -def _test_matmul(i, j, k, transpose_a=False, transpose_b=False): - """One iteration of matmul""" - - a_shape_init = [i, j] - b_shape_init = [j, k] - a_shape = [] + (a_shape_init[::-1] if transpose_a else a_shape_init) - b_shape = [] + (b_shape_init[::-1] if transpose_b else b_shape_init) - - with tf.Graph().as_default(): - a_in = tf.placeholder(shape=a_shape, dtype="float32", name="A") - b_in = tf.placeholder(shape=b_shape, dtype="float32", name="B") - result = tf.matmul(a_in, b_in, transpose_a=transpose_a, transpose_b=transpose_b) - - a_np = np.random.uniform(high=5.0, size=a_shape).astype("float32") - b_np = np.random.uniform(high=5.0, size=b_shape).astype("float32") - io_info = { - "in_data": [a_np, b_np], - "in_name": [a_in.name, b_in.name], - "out_name": result.name, - } - graph_def, golden = get_graph_def(**io_info) - verify_model(graph_def, golden, **io_info, use_out_name=False) - - -@pytest.mark.skip(reason="Failed due to tf and tflite upgrade.") -def test_matmul(): - """test tensorflow translator for matmul""" - - _test_matmul(1, 3, 6) - _test_matmul(1, 3, 6, True, True) - _test_matmul(1, 3, 6, True, False) - _test_matmul(1, 3, 6, False, True) - - -def _test_batch_matmul(a_shape, b_shape, adjoint_a=False, adjoint_b=False): - with tf.Graph().as_default(): - a_in = tf.placeholder(shape=a_shape, dtype="float32", name="A") - b_in = tf.placeholder(shape=b_shape, dtype="float32", name="B") - result = tf.matmul(a_in, b_in, adjoint_a=adjoint_a, adjoint_b=adjoint_b, name="batchmatmul") - - a_np = np.random.uniform(high=5.0, size=a_shape).astype("float32") - b_np = np.random.uniform(high=5.0, size=b_shape).astype("float32") - io_info = { - "in_data": [a_np, b_np], - "in_name": [a_in.name, b_in.name], - "out_name": result.name, - } - graph_def, golden = get_graph_def(**io_info) - verify_model(graph_def, golden, **io_info) - - -def test_batch_matmul(): - """test tensorflow translator for batch_matmul""" - - _test_batch_matmul((3, 5, 4), (3, 4, 5)) - _test_batch_matmul((3, 5, 4), (3, 4, 5), True, True) - _test_batch_matmul((3, 5, 4), (3, 5, 4), True, False) - _test_batch_matmul((3, 5, 4), (3, 5, 4), False, True) - - -def _test_stridedslice( - ip_shape, - begin, - end, - stride, - begin_mask=0, - end_mask=0, - new_axis_mask=0, - shrink_axis_mask=0, - ellipsis_mask=0, -): - """One iteration of a Stridedslice""" - - 
tf.reset_default_graph() - np_data = np.random.uniform(size=ip_shape).astype("float32") - with tf.Graph().as_default(): - in_data = tf.placeholder("float32", ip_shape, name="in_data") - tf.strided_slice( - in_data, - begin, - end, - stride, - begin_mask=begin_mask, - end_mask=end_mask, - new_axis_mask=new_axis_mask, - shrink_axis_mask=shrink_axis_mask, - ellipsis_mask=ellipsis_mask, - name="strided_slice", - ) - io_info = {"in_data": np_data, "in_name": "in_data:0", "out_name": "strided_slice:0"} - graph_def, golden = get_graph_def(**io_info) - verify_model(graph_def, golden, **io_info) - - -def test_stridedslice(): - """test tensorflow translator for stridedslice""" - - _test_stridedslice([2, 3, 4], [0], [1], [1], shrink_axis_mask=8) - _test_stridedslice([3, 4, 3], [1, -1, 0], [4, -5, 3], [2, -1, 1]) - _test_stridedslice([3, 4, 3], [1, 0], [4, 3], [2, 1], ellipsis_mask=8) - _test_stridedslice([3, 4, 3], [1, 1, 0], [4, 4, 2], [2, 1, 1], new_axis_mask=5) - _test_stridedslice( - [3, 4, 5, 4, 5, 6], - [0, 0, 1, 2, 1], - [2, 3, 4, 5, 3], - [1, 1, 2, 2, 1], - shrink_axis_mask=5, - new_axis_mask=1, - ellipsis_mask=2, - begin_mask=8, - end_mask=8, - ) - - -def _test_divide(ip_shape, dtype): - np_numer = np.random.uniform(-100, 100, size=ip_shape).astype(dtype) - np_denomin = np.random.uniform(1, 100, size=ip_shape).astype(dtype) - tf.reset_default_graph() - with tf.Graph().as_default(): - numerator = tf.placeholder(dtype, ip_shape, name="numer") - denominator = tf.placeholder(dtype, ip_shape, name="denomin") - tf.math.divide(numerator, denominator, name="RealDiv") - io_info = { - "in_data": [np_numer, np_denomin], - "in_name": ["numer:0", "denomin:0"], - "out_name": "RealDiv:0", - } - graph_def, golden = get_graph_def(**io_info) - verify_model(graph_def, golden, **io_info) - - -def _test_floordiv(ip_shape, dtype): - np_numer = np.random.uniform(1, 100, size=ip_shape).astype(dtype) - tf.reset_default_graph() - with tf.Graph().as_default(): - numerator = tf.placeholder(dtype, ip_shape, name="numer") - tf.math.floordiv(numerator, tf.constant(5, dtype=dtype), name="FloorDiv") - io_info = {"in_data": [np_numer], "in_name": ["numer:0"], "out_name": "FloorDiv:0"} - graph_def, golden = get_graph_def(**io_info) - verify_model(graph_def, golden, **io_info) - - -def test_divide(): - """test tensorflow translator for div""" - - _test_divide((4, 3, 7), "float32") - _test_divide((4, 3, 7), "int32") - _test_floordiv((4, 3, 7), "float32") - _test_floordiv((4, 3, 7), "int32") - - -def _test_gather(ip_shape, indice_shape, indice_value, axis, batch_dims): - """One iteration of a GatherV2""" - - tf.reset_default_graph() - with tf.Graph().as_default(): - in_data = tf.placeholder("float32", ip_shape, name="in_data") - indices = tf.placeholder("int32", indice_shape, name="indices") - out = tf.gather(in_data, indices, axis=axis, batch_dims=batch_dims) - np_data = np.random.uniform(1, 10, size=ip_shape).astype("float32") - - def _fill_indices(indice_value): - indices = np.array(ip_shape, dtype="float32") - if isinstance(indice_value, int): - indices = np.array([indice_value], dtype="int32") - else: - indices = np.asarray(indice_value, dtype="int32") - return indices - - np_indices = _fill_indices(indice_value) - io_info = { - "in_data": [np_data, np_indices], - "in_name": ["in_data:0", "indices:0"], - "out_name": out.name, - } - graph_def, golden = get_graph_def(**io_info) - verify_model(graph_def, golden, **io_info) - - -def test_gather(): - """test tensorflow translator for gather""" - - _test_gather((4,), (1,), 1, 0, 
0) - _test_gather((2, 2), (1, 2, 2), [[[1, 0], [0, 1]]], 0, 0) - _test_gather((4, 3, 5, 6), (1, 4), [[2, 1, 0, 0]], 0, 0) - - -def _test_split(in_shape, axis, num_or_size_splits): - """One iteration of a Split""" - np_data = np.random.uniform(-5, 5, size=in_shape).astype("float32") - - tf.reset_default_graph() - with tf.Graph().as_default(): - in_data = tf.placeholder("float32", in_shape, name="in_data") - _ = len(num_or_size_splits) if isinstance(num_or_size_splits, list) else num_or_size_splits - split = tf.split(in_data, num_or_size_splits, axis=axis) - relu = [tf.nn.relu(i) for i in split] - io_info = { - "in_data": [np_data], - "in_name": ["in_data:0"], - "out_name": [n.name for n in relu], - } - graph_def, golden = get_graph_def(**io_info) - verify_model(graph_def, golden, **io_info) - - # and now test together with concat - tf.reset_default_graph() - with tf.Graph().as_default(): - in_data = tf.placeholder("float32", in_shape, name="in_data") - splitted = tf.split(in_data, num_or_size_splits, axis=axis) - concat = tf.concat(splitted, axis) - io_info = { - "in_data": [np_data], - "in_name": ["in_data:0"], - "out_name": concat.name, - } - graph_def, golden = get_graph_def(**io_info) - verify_model(graph_def, golden, **io_info) - - -def test_split(): - """test tensorflow translator for split""" - - _test_split((6, 1, 3, 5), 0, 3) - _test_split((6, 1, 3, 5), -4, 3) - _test_split((3, 6, 4), -2, [1, 4, 1]) - - -def _test_tile(in_shape, multiples): - np_data = np.random.uniform(-5, 5, size=in_shape).astype("float32") - tf.reset_default_graph() - with tf.Graph().as_default(): - in_data = tf.placeholder("float32", in_shape, name="in_data") - tf.tile(in_data, multiples=multiples, name="tile") - io_info = {"in_data": np_data, "in_name": "in_data:0", "out_name": "tile:0"} - graph_def, golden = get_graph_def(**io_info) - verify_model(graph_def, golden, **io_info) - - -def test_tile(): - """test tensorflow translator for tile""" - - _test_tile((2, 2), (2, 3)) - - -def _test_clip_by_value(ip_shape, clip_value_min, clip_value_max): - tf.reset_default_graph() - with tf.Graph().as_default(): - in_data = tf.placeholder("float32", ip_shape, name="in_data") - tf.clip_by_value(in_data, clip_value_min, clip_value_max, name="ClipByValue") - np_data = np.random.uniform(-100, 100, size=ip_shape).astype("float32") - io_info = {"in_data": np_data, "in_name": "in_data:0", "out_name": "ClipByValue:0"} - graph_def, golden = get_graph_def(**io_info) - verify_model(graph_def, golden, **io_info) - - -def test_clip_by_value(): - """test tensorflow translator for clip""" - - _test_clip_by_value((4,), 0.1, 5.0) - - -def test_multi_input(): - """test tensorflow translator for multi input""" - - with tf.Graph().as_default(): - in1 = tf.placeholder(tf.int32, shape=[3, 3], name="in1") - in2 = tf.placeholder(tf.int32, shape=[3, 3], name="in2") - in3 = tf.placeholder(tf.int32, shape=[3, 3], name="in3") - in4 = tf.placeholder(tf.int32, shape=[3, 3], name="in4") - - out1 = tf.add(in1, in2, name="out1") - out2 = tf.subtract(in3, in4, name="out2") - _ = tf.multiply(out1, out2, name="out") - in_data = np.arange(9, dtype="int32").reshape([3, 3]) - io_info = { - "in_data": [in_data, in_data, in_data, in_data], - "in_name": ["in1:0", "in2:0", "in3:0", "in4:0"], - "out_name": "out:0", - } - graph_def, golden = get_graph_def(**io_info) - verify_model(graph_def, golden, **io_info) - - -def test_multi_output(): - """test tensorflow translator for multi output""" - - with tf.Graph().as_default(): - in1 = tf.placeholder(tf.int32, 
shape=[3, 3], name="in1") - in2 = tf.placeholder(tf.int32, shape=[3, 3], name="in2") - in3 = tf.placeholder(tf.int32, shape=[3, 3], name="in3") - in4 = tf.placeholder(tf.int32, shape=[3, 3], name="in4") - - _ = tf.add(in1, in2, name="out1") - _ = tf.subtract(in3, in4, name="out2") - in_data = np.arange(9, dtype="int32").reshape([3, 3]) - io_info = { - "in_data": [in_data] * 4, - "in_name": ["in1:0", "in2:0", "in3:0", "in4:0"], - "out_name": ["out1:0", "out2:0"], - } - graph_def, golden = get_graph_def(**io_info) - verify_model(graph_def, golden, **io_info) - - -def _test_resize_bilinear(in_shape, to_shape, align_corners): - """One iteration of resize bilinear""" - - data = np.random.uniform(size=in_shape).astype("float32") - shape_data = np.array(to_shape).astype("int32") - - with tf.Graph().as_default(): - in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype) - shape_data = constant_op.constant( - shape_data, shape=shape_data.shape, dtype=shape_data.dtype - ) - tf.image.resize_bilinear(in_data, shape_data, align_corners=align_corners) - io_info = {"in_data": data, "in_name": "Placeholder:0", "out_name": "ResizeBilinear:0"} - graph_def, golden = get_graph_def(**io_info) - verify_model(graph_def, golden, **io_info) - - -def _test_resize_nearest_neighbor(in_shape, to_shape): - """One iteration of resize nearest neighbor""" - - data = np.random.uniform(size=in_shape).astype("float32") - shape_data = np.array(to_shape).astype("int32") - - with tf.Graph().as_default(): - in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype) - shape_data = constant_op.constant( - shape_data, shape=shape_data.shape, dtype=shape_data.dtype - ) - tf.image.resize_nearest_neighbor(in_data, shape_data, name="resize_nearest_neighbor") - io_info = { - "in_data": data, - "in_name": "Placeholder:0", - "out_name": "resize_nearest_neighbor:0", - } - graph_def, golden = get_graph_def(**io_info) - verify_model(graph_def, golden, **io_info) - - -def test_resize(): - """test tensorflow translator for resize""" - - _test_resize_bilinear((4, 32, 32, 3), [50, 50], False) - _test_resize_bilinear((6, 32, 32, 3), [20, 20], True) - _test_resize_nearest_neighbor((6, 32, 32, 3), [20, 20]) - - -def _test_broadcast_to(in_shape, to_shape): - """One iteration of broadcast_to""" - - data = np.random.uniform(size=in_shape).astype("float32") - shape_data = np.array(to_shape).astype("int32") - - with tf.Graph().as_default(): - in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype) - shape_data = constant_op.constant( - shape_data, shape=shape_data.shape, dtype=shape_data.dtype - ) - tf.broadcast_to(in_data, shape_data) - io_info = {"in_data": data, "in_name": "Placeholder:0", "out_name": "BroadcastTo:0"} - graph_def, golden = get_graph_def(**io_info) - verify_model(graph_def, golden, **io_info) - - -def test_broadcast_to(): - """test tensorflow translator for broadcast_to""" - - _test_broadcast_to((4, 1, 32, 32), [4, 8, 32, 32]) - - -def _test_fill(in_shape): - """Use the fill op to create a tensor of ones with non-constant shape.""" - - with tf.Graph().as_default(): - tf.ones(shape=in_shape, dtype="float32") - io_info = {"in_data": in_shape, "in_name": [], "out_name": "ones:0"} - graph_def, golden = get_graph_def(**io_info) - verify_model(graph_def, golden, **io_info, use_out_name=False) - - -def test_fill(): - """test tensorflow translator for fill""" - - _test_fill((6, 32, 64, 64)) - - -def _test_pack(axis, shape, **kwargs): - a = np.arange(np.prod(shape), dtype=np.float32).reshape(shape) - b = 
np.arange(np.prod(shape), dtype=np.float32).reshape(shape) - - with tf.Graph().as_default(): - tf_a = array_ops.placeholder(shape=shape, dtype="float32", name="pl_a") - tf_b = array_ops.placeholder(shape=shape, dtype="float32", name="pl_b") - tf_c = tf.stack([tf_a, tf_b], axis=axis, **kwargs) - assert tf_c.op.op_def.name == "Pack", "tf.stack() is expected to produce 'Pack' operation" - io_info = {"in_data": [a, b], "in_name": ["pl_a:0", "pl_b:0"], "out_name": "stack:0"} - graph_def, golden = get_graph_def(**io_info) - verify_model(graph_def, golden, **io_info) - - -def test_pack(): - """test tensorflow translator for pack""" - - _test_pack(3, [3, 2, 1]) - - -def _test_unpack(in_shape, axis): - """test operator Unpack""" - np_data = np.random.uniform(-100, 100, size=in_shape).astype("float32") - tf.reset_default_graph() - with tf.Graph().as_default(): - in_data = tf.placeholder("float32", in_shape, name="in_data") - tf.unstack(in_data, axis=axis, name="Unpack") - io_info = {"in_data": [np_data], "in_name": ["in_data:0"], "out_name": "Unpack:0"} - graph_def, golden = get_graph_def(**io_info) - verify_model(graph_def, golden, **io_info) - - -def test_unpack(): - """test tensorflow translator for unpack""" - - _test_unpack((21, 23, 3), 2) - - -def _test_einsum(equation, *shape_of_input_tensors): - """Test Einsum Op""" - - with tf.Graph().as_default(): - inputs_placeholders = [] - input_data = [] - for idx, shape in enumerate(shape_of_input_tensors): - input_name = f"input_{idx}" - inputs_placeholders.append( - tf.placeholder(shape=shape, dtype="float32", name=input_name) - ) - input_data.append(np.random.normal(size=shape).astype("float32")) - - result = tf.einsum(equation, *inputs_placeholders) - io_info = { - "in_data": input_data, - "in_name": [ph.name for ph in inputs_placeholders], - "out_name": result.name, - } - graph_def, golden = get_graph_def(**io_info) - verify_model(graph_def, golden, **io_info, use_out_name=False) - - -def test_einsum(): - """test tensorflow translator for einsum""" - - _test_einsum("ij,jk->ik", [2, 3], [3, 5]) # Matmul - _test_einsum("ij,jk", [2, 3], [3, 5]) # Matmul - _test_einsum("i,i->", [2], [2]) # Dot product - _test_einsum("i,j->ij", [3], [5]) # Outer produce - _test_einsum("ij->ji", [2, 3]) # Transpose - _test_einsum("ii->i", [3, 3]) # Diag - _test_einsum("ii", [3, 3]) # Trace of a square matrix - _test_einsum("bij,bjk->bik", [7, 5, 3], [7, 3, 2]) # Batch matmul - - -def _test_pad(input_shape, paddings, mode, **kwargs): - """One iteration of pad operation with given shape""" - - x = np.arange(np.prod(input_shape), dtype=np.float32).reshape(input_shape) - - with tf.Graph().as_default(): - in_data = array_ops.placeholder(shape=input_shape, dtype="float32") - pad_values = constant_op.constant(paddings) - _ = tf.pad(in_data, paddings=pad_values, mode=mode, **kwargs) - - if mode == "CONSTANT": - if "constant_values" in kwargs: - out_name = "PadV2:0" - else: - out_name = "Pad:0" - else: - out_name = "MirrorPad:0" - - io_info = { - "in_data": x, - "in_name": "Placeholder:0", - "out_name": out_name, - } - graph_def, golden = get_graph_def(**io_info) - verify_model(graph_def, golden, **io_info) - - -def test_pad(): - """test tensorflow translator for pad""" - - _test_pad((2, 3), [[1, 1], [2, 2]], mode="CONSTANT") - _test_pad((2, 3), [[1, 1], [2, 2]], mode="CONSTANT", constant_values=1.0) - - -def test_logical_and(): - """test tensorflow translator for logical_and""" - - with tf.Graph().as_default(): - in1 = tf.placeholder(tf.bool, shape=[1, 4, 4, 3], name="in1") - 
in2 = tf.placeholder(tf.bool, shape=[1, 4, 4, 3], name="in2") - _ = tf.logical_and(in1, in2, name="out") - in_data1 = np.random.choice(a=[False, True], size=(1, 4, 4, 3)).astype("bool") - in_data2 = np.random.choice(a=[False, True], size=(1, 4, 4, 3)).astype("bool") - io_info = { - "in_data": [in_data1, in_data2], - "in_name": ["in1:0", "in2:0"], - "out_name": "out:0", - } - graph_def, golden = get_graph_def(**io_info) - verify_model(graph_def, golden, **io_info) - - -def test_logical_or(): - """test tensorflow translator for logical_or""" - - with tf.Graph().as_default(): - in1 = tf.placeholder(tf.bool, shape=[1, 4, 4, 3], name="in1") - in2 = tf.placeholder(tf.bool, shape=[1, 4, 4, 3], name="in2") - _ = tf.logical_or(in1, in2, name="out") - in_data1 = np.random.choice(a=[False, True], size=(1, 4, 4, 3)).astype("bool") - in_data2 = np.random.choice(a=[False, True], size=(1, 4, 4, 3)).astype("bool") - io_info = { - "in_data": [in_data1, in_data2], - "in_name": ["in1:0", "in2:0"], - "out_name": "out:0", - } - graph_def, golden = get_graph_def(**io_info) - verify_model(graph_def, golden, **io_info) - - -def test_logical_xor(): - """test tensorflow translator for logical_xor""" - - with tf.Graph().as_default(): - in1 = tf.placeholder(tf.bool, shape=[1, 4, 4, 3], name="in1") - in2 = tf.placeholder(tf.bool, shape=[1, 4, 4, 3], name="in2") - _ = tf.logical_xor(in1, in2, name="out") - in_data1 = np.random.choice(a=[False, True], size=(1, 4, 4, 3)).astype("bool") - in_data2 = np.random.choice(a=[False, True], size=(1, 4, 4, 3)).astype("bool") - io_info = { - "in_data": [in_data1, in_data2], - "in_name": ["in1:0", "in2:0"], - "out_name": "out:0", - } - graph_def, golden = get_graph_def(**io_info) - verify_model(graph_def, golden, **io_info) - - -def test_logical_not(): - """test tensorflow translator for logical_not""" - - with tf.Graph().as_default(): - in1 = tf.placeholder(tf.bool, shape=[1, 4, 4, 3], name="in1") - _ = tf.logical_not(in1, name="out") - in_data1 = np.random.choice(a=[False, True], size=(1, 4, 4, 3)).astype("bool") - io_info = { - "in_data": [in_data1], - "in_name": ["in1:0"], - "out_name": "out:0", - } - graph_def, golden = get_graph_def(**io_info) - verify_model(graph_def, golden, **io_info) - - -def test_where(): - """test tensorflow translator for where""" - - with tf.Graph().as_default(): - with tf.Session() as _: - input1 = tf.placeholder(tf.int32, shape=[1, 4, 4, 3], name="input1") - input2 = tf.placeholder(tf.int32, shape=[1, 4, 4, 3], name="input2") - mask = input1 > input2 - tf.where(mask, input1 + 1, input2 * 2) - in_data1 = np.random.uniform(0, 10, size=(1, 4, 4, 3)).astype("uint32") - in_data2 = np.random.uniform(0, 10, size=(1, 4, 4, 3)).astype("uint32") - io_info = { - "in_data": [in_data1, in_data2], - "in_name": ["input1:0", "input2:0"], - "out_name": "Select:0", - } - graph_def, golden = get_graph_def(**io_info) - verify_model(graph_def, golden, **io_info) - - -def _test_transpose(ishape, axes=None): - data = np.random.uniform(size=ishape).astype(np.float32) - - with tf.Graph().as_default(): - in1 = tf.placeholder(shape=data.shape, dtype=data.dtype, name="transpose_data") - - if axes is None: - tf.transpose(in1) - else: - tf.transpose(in1, perm=axes) - - io_info = { - "in_data": data, - "in_name": "transpose_data:0", - "out_name": "transpose:0", - } - graph_def, golden = get_graph_def(**io_info) - verify_model(graph_def, golden, **io_info) - - -def _test_tranapose_axes_input(ishape, axes): - data = np.random.uniform(size=ishape).astype(np.float32) - axes_np = 
np.array(axes).astype(np.int32) - - with tf.Graph().as_default(): - in1 = tf.placeholder(shape=data.shape, dtype=data.dtype, name="transpose_data") - - const1 = tf.constant(axes_np, dtype=tf.int32) - - # make axes an input to tf.transpose, but not an input to the graph, - # so it can be extracted with infer_value_simulated - axes = tf.reverse(const1, axis=[-1]) - tf.transpose(in1, axes) - io_info = { - "in_data": data, - "in_name": "transpose_data:0", - "out_name": "transpose:0", - } - graph_def, golden = get_graph_def(**io_info) - verify_model(graph_def, golden, **io_info) - - -def test_transpose(): - """test tensorflow translator for transpose""" - - _test_transpose((2, 3, 4), (1, 2, 0)) - _test_transpose((2, 3, 4)) - _test_tranapose_axes_input((2, 3, 4), (1, 2, 0)) - _test_tranapose_axes_input((2, 3, 4, 5), (3, 0, 1, 2)) - - -def _test_slice_operation_input(input_value, begin_value, size_value): - input_data = np.array(input_value, dtype=np.float32) - with tf.Graph().as_default(): - input_tensor = tf.placeholder(shape=input_data.shape, dtype=input_data.dtype, name="input") - tf.slice(input_tensor, begin_value, size_value, name="slice_output") - io_info = { - "in_data": input_data, - "in_name": "input:0", - "out_name": "slice_output:0", - } - graph_def, golden = get_graph_def(**io_info) - verify_model(graph_def, golden, **io_info) - - -def test_slice(): - """test tensorflow translator for slice""" - - _test_slice_operation_input([1, 1], [0], [2]) - - -def test_ceil(): - """test tensorflow translator for ceil""" - - ishape = (1, 3, 10, 10) - inp_array = np.random.uniform(size=ishape).astype(np.float32) - with tf.Graph().as_default(): - in1 = tf.placeholder(shape=inp_array.shape, dtype=inp_array.dtype) - tf.ceil(in1) - io_info = { - "in_data": inp_array, - "in_name": "Placeholder:0", - "out_name": "Ceil:0", - } - graph_def, golden = get_graph_def(**io_info) - verify_model(graph_def, golden, **io_info) - - -def test_floor(): - """test tensorflow translator for floor""" - - ishape = (1, 3, 10, 10) - inp_array = np.random.uniform(size=ishape).astype(np.float32) - with tf.Graph().as_default(): - in1 = tf.placeholder(shape=inp_array.shape, dtype=inp_array.dtype) - tf.floor(in1) - io_info = { - "in_data": inp_array, - "in_name": "Placeholder:0", - "out_name": "Floor:0", - } - graph_def, golden = get_graph_def(**io_info) - verify_model(graph_def, golden, **io_info) - - -def test_relu(): - """test tensorflow translator for relu""" - - ishape = (1, 3, 10, 10) - inp_array = np.random.uniform(-5, 5, size=ishape).astype(np.float32) - with tf.Graph().as_default(): - in1 = tf.placeholder(shape=inp_array.shape, dtype=inp_array.dtype) - tf.nn.relu(in1) - io_info = { - "in_data": inp_array, - "in_name": "Placeholder:0", - "out_name": "Relu:0", - } - graph_def, golden = get_graph_def(**io_info) - verify_model(graph_def, golden, **io_info) - - -def test_elu(): - """test tensorflow translator for elu""" - - ishape = (1, 3, 10, 10) - inp_array = np.random.uniform(-5, 5, size=ishape).astype(np.float32) - with tf.Graph().as_default(): - in1 = tf.placeholder(shape=inp_array.shape, dtype=inp_array.dtype) - tf.nn.elu(in1) - io_info = { - "in_data": inp_array, - "in_name": "Placeholder:0", - "out_name": "Elu:0", - } - graph_def, golden = get_graph_def(**io_info) - verify_model(graph_def, golden, **io_info) - - -def test_selu(): - """test tensorflow translator for selu""" - - ishape = (1, 3, 10, 10) - inp_array = np.random.uniform(-5, 5, size=ishape).astype(np.float32) - with tf.Graph().as_default(): - in1 = 
tf.placeholder(shape=inp_array.shape, dtype=inp_array.dtype) - tf.nn.selu(in1) - io_info = { - "in_data": inp_array, - "in_name": "Placeholder:0", - "out_name": "Selu:0", - } - graph_def, golden = get_graph_def(**io_info) - verify_model(graph_def, golden, **io_info) - - -def test_tanh(): - """test tensorflow translator for tanh""" - - ishape = (1, 3, 10, 10) - inp_array = np.random.uniform(-5, 5, size=ishape).astype(np.float32) - with tf.Graph().as_default(): - in1 = tf.placeholder(shape=inp_array.shape, dtype=inp_array.dtype) - tf.nn.tanh(in1) - io_info = { - "in_data": inp_array, - "in_name": "Placeholder:0", - "out_name": "Tanh:0", - } - graph_def, golden = get_graph_def(**io_info) - verify_model(graph_def, golden, **io_info) - - -def test_softmax(): - """test tensorflow translator for softmax""" - - def check_softmax(in_shape, axis, dtype): - np_data = np.random.uniform(-100, 100, size=in_shape).astype(dtype) - tf.reset_default_graph() - with tf.Graph().as_default(): - in_data = tf.placeholder(dtype, in_shape, name="in_data") - tf.nn.softmax(in_data, axis=axis, name="Softmax") - io_info = { - "in_data": np_data, - "in_name": "in_data:0", - "out_name": "Softmax:0", - } - graph_def, golden = get_graph_def(**io_info) - verify_model(graph_def, golden, **io_info) - - check_softmax((2, 3, 5), 2, "float32") - check_softmax((2, 3, 5), -1, "float32") - - -def test_round(): - """test tensorflow translator for round""" - - np_data = np.random.uniform(-10, 10, size=(5, 7)).astype(np.float32) - tf.reset_default_graph() - with tf.Graph().as_default(): - in_data = tf.placeholder(tf.float32, (5, 7), name="in_data") - tf.round(in_data, name="round") - io_info = { - "in_data": np_data, - "in_name": "in_data:0", - "out_name": "round:0", - } - graph_def, golden = get_graph_def(**io_info) - verify_model(graph_def, golden, **io_info) - - -def test_abs(): - """test tensorflow translator for abs""" - - np_data = np.random.uniform(1, 100, size=(9, 11)).astype(np.float32) - tf.reset_default_graph() - with tf.Graph().as_default(): - in_data = tf.placeholder(tf.float32, (9, 11), name="in_data") - tf.math.abs(in_data, name="abs") - io_info = { - "in_data": np_data, - "in_name": "in_data:0", - "out_name": "abs:0", - } - graph_def, golden = get_graph_def(**io_info) - verify_model(graph_def, golden, **io_info) - - -def test_squared_difference(): - """test tensorflow translator for squared_difference""" - - ishape = (1, 3, 10, 14) - inp_array_a = np.random.uniform(-5, 5, size=ishape).astype(np.float32) - inp_array_b = np.random.uniform(-5, 5, size=ishape).astype(np.float32) - with tf.Graph().as_default(): - in1 = tf.placeholder(shape=inp_array_a.shape, dtype=inp_array_a.dtype, name="in1") - in2 = tf.placeholder(shape=inp_array_b.shape, dtype=inp_array_b.dtype, name="in2") - out = tf.math.squared_difference(in1, in2) - io_info = { - "in_data": [inp_array_a, inp_array_b], - "in_name": [in1.name, in2.name], - "out_name": out.name, - } - graph_def, golden = get_graph_def(**io_info) - verify_model(graph_def, golden, **io_info) - - -def test_sign(): - """test tensorflow translator for sign""" - - np_data = np.random.uniform(-10, 10, size=(5, 7, 11)).astype(np.float32) - tf.reset_default_graph() - with tf.Graph().as_default(): - in_data = tf.placeholder(tf.float32, (5, 7, 11), name="in_data") - tf.sign(in_data, name="sign") - io_info = {"in_data": [np_data], "in_name": ["in_data:0"], "out_name": "sign:0"} - graph_def, golden = get_graph_def(**io_info) - verify_model(graph_def, golden, **io_info) - - -def test_square(): - 
"""test tensorflow translator for square""" - - np_data = np.random.uniform(1, 100, size=(2, 3, 5)).astype(np.float32) - tf.reset_default_graph() - with tf.Graph().as_default(): - in_data = tf.placeholder(tf.float32, (2, 3, 5), name="in_data") - tf.square(in_data, name="square") - io_info = {"in_data": [np_data], "in_name": ["in_data:0"], "out_name": "square:0"} - graph_def, golden = get_graph_def(**io_info) - verify_model(graph_def, golden, **io_info) - - -def test_pow_exp(): - """test tensorflow translator for pow && exp""" - - np_in1 = np.random.uniform(-2, 2, size=(5, 7, 11)).astype(np.float32) - np_in2 = np.random.uniform(-2, 2, size=(5, 7, 11)).astype(np.float32) - tf.reset_default_graph() - with tf.Graph().as_default(): - in1 = tf.placeholder(tf.float32, (5, 7, 11), name="in1") - in2 = tf.placeholder(tf.float32, (5, 7, 11), name="in2") - in3 = tf.pow(in1, in2, name="pow") - _ = tf.exp(in3, name="exp") - io_info = {"in_data": [np_in1, np_in2], "in_name": ["in1:0", "in2:0"], "out_name": "exp:0"} - graph_def, golden = get_graph_def(**io_info) - verify_model(graph_def, golden, **io_info) - - -def test_unary(): - """test tensorflow translator for unary""" - - def _test_unary(op, a_min=1, a_max=5, dtype=np.float32): - """test unary operators""" - np_data = np.random.uniform(a_min, a_max, size=(2, 3, 5)).astype(dtype) - tf.reset_default_graph() - with tf.Graph().as_default(): - in_data = tf.placeholder(dtype, (2, 3, 5), name="in_data") - out = op(in_data) - io_info = {"in_data": [np_data], "in_name": ["in_data:0"], "out_name": out.name} - graph_def, golden = get_graph_def(**io_info) - verify_model(graph_def, golden, **io_info) - - _test_unary(tf.acos, -1, 1) - _test_unary(tf.asin, -1, 1) - _test_unary(tf.atanh, -1, 1) - _test_unary(tf.sinh) - _test_unary(tf.cosh) - _test_unary(tf.acosh) - _test_unary(tf.asinh) - _test_unary(tf.atan) - _test_unary(tf.sin) - _test_unary(tf.cos) - _test_unary(tf.tan) - _test_unary(tf.tanh) - _test_unary(tf.erf) - _test_unary(tf.log) - - -def test_atan2(): - """test tensorflow translator for atan2""" - - tf.disable_eager_execution() - np_data_1 = np.random.uniform(1, 100, size=(2, 3, 5)).astype(np.float32) - np_data_2 = np.random.uniform(1, 100, size=(2, 3, 5)).astype(np.float32) - tf.reset_default_graph() - with tf.Graph().as_default(): - in_data_1 = tf.placeholder(tf.float32, (2, 3, 5), name="in_data_1") - in_data_2 = tf.placeholder(tf.float32, (2, 3, 5), name="in_data_2") - tf.atan2(in_data_1, in_data_2, name="atan2") - io_info = { - "in_data": [np_data_1, np_data_2], - "in_name": ["in_data_1:0", "in_data_2:0"], - "out_name": "atan2:0", - } - graph_def, golden = get_graph_def(**io_info) - verify_model(graph_def, golden, **io_info) - - -def test_expm1(): - """test tensorflow translator for expm1""" - - def _test_expm1(shape): - tf.disable_eager_execution() - np_data = np.random.uniform(1, 10, size=shape).astype(np.float32) - tf.reset_default_graph() - with tf.Graph().as_default(): - in_data = tf.placeholder(tf.float32, shape, name="in_data") - tf.expm1(in_data, name="expm1") - io_info = {"in_data": [np_data], "in_name": ["in_data:0"], "out_name": "expm1:0"} - graph_def, golden = get_graph_def(**io_info) - verify_model(graph_def, golden, **io_info) - - _test_expm1([2, 5, 2, 5]) - - -def test_softsign(): - """test tensorflow translator for softsign""" - - def _test_softsign(shape): - tf.disable_eager_execution() - np_data = np.random.uniform(1, 100, size=shape).astype(np.float32) - tf.reset_default_graph() - with tf.Graph().as_default(): - in_data = 
tf.placeholder(tf.float32, shape, name="in_data") - tf.nn.softsign(in_data, name="softsign") - io_info = {"in_data": [np_data], "in_name": ["in_data:0"], "out_name": "softsign:0"} - graph_def, golden = get_graph_def(**io_info) - verify_model(graph_def, golden, **io_info) - - _test_softsign([2, 5, 2, 5]) - - -def test_rint(): - """test tensorflow translator for rint""" - - def _test_rint(shape): - tf.disable_eager_execution() - np_data = np.random.uniform(-100, 100, size=shape).astype(np.float32) - tf.reset_default_graph() - with tf.Graph().as_default(): - in_data = tf.placeholder(tf.float32, shape, name="in_data") - tf.math.rint(in_data, name="rint") - io_info = {"in_data": [np_data], "in_name": ["in_data:0"], "out_name": "rint:0"} - graph_def, golden = get_graph_def(**io_info) - verify_model(graph_def, golden, **io_info) - - _test_rint([2, 5, 2, 5]) - - -def test_negative(): - """test tensorflow translator for neg""" - - np_data = np.random.uniform(-100, 255, size=(224, 224, 3)).astype(np.float32) - tf.reset_default_graph() - with tf.Graph().as_default(): - in_data = tf.placeholder(tf.float32, (224, 224, 3), name="in_data") - tf.negative(in_data, name="negative") - io_info = {"in_data": [np_data], "in_name": ["in_data:0"], "out_name": "negative:0"} - graph_def, golden = get_graph_def(**io_info) - verify_model(graph_def, golden, **io_info) - - -def test_log_softmax(): - """test tensorflow translator for log_softmax""" - - np_data = np.random.uniform(1, 100, size=(9, 11)).astype(np.float32) - tf.reset_default_graph() - with tf.Graph().as_default(): - in_data = tf.placeholder(tf.float32, (9, 11), name="in_data") - tf.math.log_softmax(in_data, name="LogSoftmax") - io_info = {"in_data": [np_data], "in_name": ["in_data:0"], "out_name": "LogSoftmax:0"} - graph_def, golden = get_graph_def(**io_info) - verify_model(graph_def, golden, **io_info) - - -def test_softplus(): - """test tensorflow translator for softplus""" - - np_data = np.random.uniform(1, 10, size=(2, 3, 5)).astype(np.float32) - tf.reset_default_graph() - with tf.Graph().as_default(): - in_data = tf.placeholder(tf.float32, (2, 3, 5), name="in_data") - tf.nn.softplus(in_data, name="softplus") - io_info = {"in_data": [np_data], "in_name": ["in_data:0"], "out_name": "softplus:0"} - graph_def, golden = get_graph_def(**io_info) - verify_model(graph_def, golden, **io_info) - - -def test_rsqrt(): - """test tensorflow translator for rsqrt""" - - np_data = np.random.uniform(1, 100, size=(5, 7, 11)).astype(np.float32) - tf.reset_default_graph() - with tf.Graph().as_default(): - in_data = tf.placeholder(tf.float32, (5, 7, 11), name="in_data") - tf.rsqrt(in_data, name="rsqrt") - io_info = {"in_data": [np_data], "in_name": ["in_data:0"], "out_name": "rsqrt:0"} - graph_def, golden = get_graph_def(**io_info) - verify_model(graph_def, golden, **io_info) - - -def test_sqrt(): - """test tensorflow translator for sqrt""" - - np_data = np.random.uniform(1, 100, size=(5, 7, 11)).astype(np.float32) - tf.reset_default_graph() - with tf.Graph().as_default(): - in_data = tf.placeholder(tf.float32, (5, 7, 11), name="in_data") - tf.sqrt(in_data, name="sqrt") - io_info = {"in_data": [np_data], "in_name": ["in_data:0"], "out_name": "sqrt:0"} - graph_def, golden = get_graph_def(**io_info) - verify_model(graph_def, golden, **io_info) - - -def test_mean(): - """test tensorflow translator for mean""" - - def check_mean(ishape, **kwargs): - inp_array = np.random.uniform(size=ishape).astype(np.float32) - with tf.Graph().as_default(): - in1 = 
tf.placeholder(shape=inp_array.shape, dtype=inp_array.dtype) - tf.keras.backend.mean(in1, **kwargs) - io_info = {"in_data": inp_array, "in_name": "Placeholder:0", "out_name": "Mean:0"} - graph_def, golden = get_graph_def(**io_info) - verify_model(graph_def, golden, **io_info) - - check_mean((10, 8, 16, 32)) - check_mean((10, 8, 16, 32), axis=(2, 3)) - check_mean((10, 8, 16, 32), axis=(1, 2), keepdims=True) - - -def test_reduce(): - """test tensorflow translator for reduce""" - - def _check_op(tf_op, ishape, axis, keepdims): - tf.reset_default_graph() - np_data = np.random.uniform(size=ishape).astype("float32") - if tf_op == tf.math.reduce_prod: - axis = 1 - np_data = np_data.reshape(1, -1) - with tf.Graph().as_default(): - in_data = tf.placeholder(shape=np_data.shape, dtype="float32", name="in_data") - reduce_op = tf_op(in_data, axis=axis, keepdims=keepdims, name="reduce_op") - io_info = {"in_data": np_data, "in_name": "in_data:0", "out_name": reduce_op.name} - graph_def, golden = get_graph_def(**io_info) - verify_model(graph_def, golden, **io_info) - - def _test_math_op(op): - _check_op(op, (8, 16, 32), axis=(-1), keepdims=False) - _check_op(op, (1, 8, 8, 3), axis=(2, 3), keepdims=True) - - _test_math_op(tf.math.reduce_max) - _test_math_op(tf.math.reduce_min) - _test_math_op(tf.math.reduce_prod) - _test_math_op(tf.math.reduce_variance) - _test_math_op(tf.math.reduce_std) - _test_math_op(tf.math.reduce_logsumexp) - if package_version.parse(tf.VERSION) >= package_version.parse("1.15.0"): - _test_math_op(tf.math.reduce_euclidean_norm) - - -def _test_rel_op(data, func): - with tf.Graph().as_default(): - in1 = tf.placeholder(shape=data[0].shape, dtype=data[0].dtype, name="in1") - in2 = tf.placeholder(shape=data[1].shape, dtype=data[1].dtype, name="in2") - op = func(in1, in2, name="op") - _ = tf.cast(op, tf.int32, name="out1") - io_info = { - "in_data": [data[0], data[1]], - "in_name": ["in1:0", "in2:0"], - "out_name": "out1:0", - } - graph_def, golden = get_graph_def(**io_info) - verify_model(graph_def, golden, **io_info) - - -def test_rel_ops(): - """test tensorflow translator for relation""" - - t_1 = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) - t_2 = np.array([[9, 8, 7], [6, 5, 4], [3, 2, 1]]) - _test_rel_op([t_1, t_2], math_ops.less) - _test_rel_op([t_1, t_2], math_ops.greater) - _test_rel_op([t_1, t_2], math_ops.less_equal) - _test_rel_op([t_1, t_2], math_ops.greater_equal) - _test_rel_op([t_1, t_2], math_ops.equal) - _test_rel_op([t_1, t_2], math_ops.not_equal) - - -def _test_expand_dims(data, axis): - with tf.Graph().as_default(): - in1 = tf.placeholder(shape=data.shape, dtype=data.dtype, name="in1") - out = tf.expand_dims(in1, axis) - io_info = {"in_data": data, "in_name": in1.name, "out_name": out.name} - graph_def, golden = get_graph_def(**io_info) - verify_model(graph_def, golden, **io_info) - - -def test_expand_dims(): - """test tensorflow translator for expand_dims""" - - _test_expand_dims(np.array([1]), -1) - _test_expand_dims(np.array([[1], [2]]), 1) - - -def test_maximum(): - """test tensorflow translator for maximum""" - - def check_maximum(lh_shape, rh_shape, dtype): - tf.reset_default_graph() - lh_data = np.random.uniform(size=lh_shape).astype(dtype) - rh_data = np.random.uniform(size=rh_shape).astype(dtype) - with tf.Graph().as_default(): - lft_data = tf.placeholder(shape=lh_data.shape, dtype=dtype, name="lft_data") - rgt_data = tf.placeholder(shape=rh_data.shape, dtype=dtype, name="rgt_data") - tf.math.maximum(lft_data, rgt_data, name="maximum") - io_info = { - "in_data": 
[lh_data, rh_data], - "in_name": ["lft_data:0", "rgt_data:0"], - "out_name": "maximum:0", - } - graph_def, golden = get_graph_def(**io_info) - verify_model(graph_def, golden, **io_info) - - check_maximum((10, 8, 16, 32), (10, 8, 16, 32), dtype="float32") - - -def test_minimum(): - """test tensorflow translator for minimum""" - - def check_minimum(lh_shape, rh_shape, dtype): - tf.reset_default_graph() - lh_data = np.random.uniform(size=lh_shape).astype(dtype) - rh_data = np.random.uniform(size=rh_shape).astype(dtype) - with tf.Graph().as_default(): - lft_data = tf.placeholder(shape=lh_data.shape, dtype=dtype, name="lft_data") - rgt_data = tf.placeholder(shape=rh_data.shape, dtype=dtype, name="rgt_data") - tf.math.minimum(lft_data, rgt_data, name="minimum") - io_info = { - "in_data": [lh_data, rh_data], - "in_name": ["lft_data:0", "rgt_data:0"], - "out_name": "minimum:0", - } - graph_def, golden = get_graph_def(**io_info) - verify_model(graph_def, golden, **io_info) - - check_minimum((10, 8, 16, 32), (10, 8, 16, 32), dtype="float32") - - -def _test_add_n(inputs): - tf.reset_default_graph() - with tf.Graph().as_default(): - temp = [] - for each in inputs: - temp.append(tf.placeholder(shape=each.shape, dtype=each.dtype)) - output = tf.add_n(temp) - io_info = { - "in_data": list(inputs), - "in_name": [each.name for each in temp], - "out_name": output.name, - } - graph_def, golden = get_graph_def(**io_info) - verify_model(graph_def, golden, **io_info) - - -def test_add_n(): - """test tensorflow translator for add_n""" - - x_in = np.random.randint(1, 100, size=(3, 3, 3), dtype=np.int32) - y_in = np.random.randint(1, 100, size=(3, 3, 3), dtype=np.int32) - z_in = np.random.randint(1, 100, size=(3, 3, 3), dtype=np.int32) - m_dim, n_dim, o_dim = x_in.astype(np.float32), y_in.astype(np.float32), z_in.astype(np.float32) - in0 = x_in - in1 = [x_in, y_in] - in2 = (x_in, y_in, z_in) - in3 = m_dim - in4 = [m_dim, n_dim] - in5 = (m_dim, n_dim, o_dim) - _test_add_n(in0) - _test_add_n(in1) - _test_add_n(in2) - _test_add_n(in3) - _test_add_n(in4) - _test_add_n(in5) - - -def _test_identityn(data_np_list): - with tf.Graph().as_default(): - data_tensors = [] - data_tensors_name = [] - for index, data_np in enumerate(data_np_list): - tensor_name = f"data_{index}" - data_tensors_name.append(tensor_name + ":0") - data_tensors.append( - tf.placeholder(shape=data_np.shape, dtype=str(data_np.dtype), name=tensor_name) - ) - - output = tf.identity_n(data_tensors) - output_names = [out.name for out in output] - io_info = {"in_data": data_np_list, "in_name": data_tensors_name, "out_name": output_names} - graph_def, golden = get_graph_def(**io_info) - verify_model(graph_def, golden, **io_info, use_out_name=False) - - -def test_identityn(): - """test tensorflow translator for identityn""" - - data_np_list = [ - np.array([[1, 1], [0, 3], [0, 1], [2, 0], [3, 1]], dtype=np.int64), - np.array([1, 2, 3, 4, 5], dtype=np.int64), - np.array([5, 6], dtype=np.int64), - ] - _test_identityn(data_np_list) - data_np_list = [ - np.array([[1, 1], [0, 3], [2, 0], [3, 1]], dtype=np.int64), - np.array([1, 2, 3, 4], dtype=np.int64), - np.array([5, 6], dtype=np.int64), - np.array([True, False, True]), - ] - _test_identityn(data_np_list) - - -def _test_infinity(tf_op, name): - """test operator infinity ops""" - - # Only float types are allowed in Tensorflow for isfinite and isinf - # float16 is failing on cuda - tf_dtypes = ["float32", "float64"] # pylint: disable=redefined-outer-name - for tf_dtype in tf_dtypes: - shape = (8, 8) - data = 
np.random.uniform(size=shape).astype(tf_dtype) - data.ravel()[np.random.choice(data.size, int(data.size * 0.5), replace=False)] = np.inf - data.ravel()[np.random.choice(data.size, int(data.size * 0.5), replace=False)] = np.nan - - tf.reset_default_graph() - in_data = tf.placeholder(tf_dtype, shape, name="in_data") - tf_op(in_data, name=name) - io_info = {"in_data": data, "in_name": "in_data:0", "out_name": f"{name}:0"} - graph_def, golden = get_graph_def(**io_info) - verify_model(graph_def, golden, **io_info) - - -def test_infinity(): - """test tensorflow translator for infinity""" - - _test_infinity(tf.is_inf, "isinf") - _test_infinity(tf.is_finite, "isfinite") - _test_infinity(tf.is_nan, "isnan") - - -if __name__ == "__main__": - tvm.testing.main() diff --git a/tests/python/contrib/test_msc/test_translate_torch.py b/tests/python/contrib/test_msc/test_translate_torch.py index 97d6e56d4059..22a959d2975d 100644 --- a/tests/python/contrib/test_msc/test_translate_torch.py +++ b/tests/python/contrib/test_msc/test_translate_torch.py @@ -27,13 +27,13 @@ from tvm.contrib.msc.core import utils as msc_utils -def verify_model(torch_model, input_info, via_relax=True): +def verify_model(torch_model, input_info): """Compare torch module results""" torch_datas = [msc_utils.random_data(i, MSCFramework.TORCH) for i in input_info] with torch.no_grad(): golden = torch_model(*torch_datas) - graph, weights = translate.from_torch(torch_model, input_info, via_relax=via_relax) + graph, weights = translate.from_torch(torch_model, input_info) model = codegen.to_torch(graph, weights) with torch.no_grad(): if not graph.get_inputs(): @@ -76,9 +76,8 @@ def forward(self, data): return self.conv(data) input_info = [([1, 3, 10], "float32")] - for via_relax in [True, False]: - verify_model(Conv1D1(), input_info, via_relax) - verify_model(Conv1D2(), input_info, via_relax) + verify_model(Conv1D1(), input_info) + verify_model(Conv1D2(), input_info) def test_conv2d(): @@ -101,9 +100,8 @@ def forward(self, data): return self.conv(data) input_info = [([1, 3, 10, 10], "float32")] - for via_relax in [True, False]: - verify_model(Conv2D1(), input_info, via_relax) - verify_model(Conv2D2(), input_info, via_relax) + verify_model(Conv2D1(), input_info) + verify_model(Conv2D2(), input_info) def test_linear(): @@ -130,10 +128,9 @@ def forward(self, x, y): return torch.matmul(x, y) input_info = [([1, 3, 10, 10], "float32")] - for via_relax in [True, False]: - verify_model(Dense1(), input_info, via_relax) - verify_model(Dense2(), input_info, via_relax) - verify_model(MatMul1(), [([10, 10], "float32"), ([10, 10], "float32")], via_relax) + verify_model(Dense1(), input_info) + verify_model(Dense2(), input_info) + verify_model(MatMul1(), [([10, 10], "float32"), ([10, 10], "float32")]) def test_bmm(): @@ -144,8 +141,7 @@ def forward(self, x, y): return torch.bmm(x, y) input_info = [((4, 128, 256), "float32"), ((4, 256, 512), "float32")] - for via_relax in [True, False]: - verify_model(BMM(), input_info, via_relax) + verify_model(BMM(), input_info) def test_baddbmm(): @@ -164,9 +160,8 @@ def forward(self, c, x, y): ((4, 128, 256), "float32"), ((4, 256, 512), "float32"), ] - for via_relax in [True, False]: - verify_model(BAddBMM1(), input_info, via_relax) - verify_model(BAddBMM2(), input_info, via_relax) + verify_model(BAddBMM1(), input_info) + verify_model(BAddBMM2(), input_info) def test_relu(): @@ -185,9 +180,8 @@ def forward(self, data): return torch.nn.functional.relu(data) input_info = [([10, 10], "float32")] - for via_relax in [True, 
False]: - verify_model(ReLU(), input_info, via_relax) - verify_model(ReLU1(), input_info, via_relax) + verify_model(ReLU(), input_info) + verify_model(ReLU1(), input_info) def test_relu6(): @@ -202,8 +196,7 @@ def forward(self, data): return self.relu6(data) input_info = [([10, 10], "float32")] - for via_relax in [True, False]: - verify_model(ReLU6(), input_info, via_relax) + verify_model(ReLU6(), input_info) def test_maxpool2d(): @@ -234,10 +227,9 @@ def forward(self, data): return self.pool(data) input_info = [([1, 3, 10, 10], "float32")] - for via_relax in [True, False]: - verify_model(MaxPool2d(), input_info, via_relax) - verify_model(MaxPool2d2(), input_info, via_relax) - verify_model(MaxPool2d3(), input_info, via_relax) + verify_model(MaxPool2d(), input_info) + verify_model(MaxPool2d2(), input_info) + verify_model(MaxPool2d3(), input_info) def test_avgpool2d(): @@ -260,9 +252,8 @@ def forward(self, data): return self.pool(data) input_info = [([1, 3, 10, 10], "float32")] - for via_relax in [True, False]: - verify_model(AvgPool2d(), input_info, via_relax) - verify_model(AvgPool2d2(), input_info, via_relax) + verify_model(AvgPool2d(), input_info) + verify_model(AvgPool2d2(), input_info) def test_adaptive_avgpool2d(): @@ -277,8 +268,7 @@ def forward(self, data): return self.pool(data) input_info = [([1, 3, 10, 10], "float32")] - for via_relax in [True, False]: - verify_model(AdaptiveAvgPool2d0(), input_info, via_relax) + verify_model(AdaptiveAvgPool2d0(), input_info) def test_flatten(): @@ -293,9 +283,8 @@ def forward(self, data): return self.f(data) input_info = [([1, 3, 10, 10], "float32")] - for via_relax in [True, False]: - verify_model(Flatten(), input_info, via_relax) - verify_model(torch.nn.Flatten(2, -1), input_info, via_relax) + verify_model(Flatten(), input_info) + verify_model(torch.nn.Flatten(2, -1), input_info) def test_batchnorm2d(): @@ -310,8 +299,7 @@ def forward(self, data): return self.batchnorm(data) input_info = [([1, 3, 10, 10], "float32")] - for via_relax in [True, False]: - verify_model(BatchNorm2d(), input_info, via_relax) + verify_model(BatchNorm2d(), input_info) def test_embedding(): @@ -325,9 +313,8 @@ def __init__(self): def forward(self, data): return self.embedding(data) - for via_relax in [True, False]: - verify_model(Embedding(), [([4], "int64")], via_relax) - verify_model(Embedding(), [([4, 5], "int64")], via_relax) + verify_model(Embedding(), [([4], "int64")]) + verify_model(Embedding(), [([4, 5], "int64")]) def test_layernorm(): @@ -374,10 +361,9 @@ def forward(self, logits, targets): return self.loss(logits, targets) input_info = [([3, 2], "float32"), ([3], "int64")] - for via_relax in [True, False]: - verify_model(CrossEntropy1(), input_info, via_relax) - verify_model(CrossEntropy2(), input_info, via_relax) - verify_model(CrossEntropy3(), input_info, via_relax) + verify_model(CrossEntropy1(), input_info) + verify_model(CrossEntropy2(), input_info) + verify_model(CrossEntropy3(), input_info) def test_silu(): @@ -396,9 +382,8 @@ def forward(self, data): return torch.nn.functional.silu(data) input_info = [([1, 3, 10, 10], "float32")] - for via_relax in [True, False]: - verify_model(SiLU(), input_info, via_relax) - verify_model(SiLU2(), input_info, via_relax) + verify_model(SiLU(), input_info) + verify_model(SiLU2(), input_info) def test_groupnorm(): @@ -413,8 +398,7 @@ def forward(self, data): return self.groupnorm(data) input_info = [([1, 3, 10, 10], "float32")] - for via_relax in [True, False]: - verify_model(GroupNorm(), input_info, via_relax) + 
verify_model(GroupNorm(), input_info) def test_softmax(): @@ -429,8 +413,7 @@ def forward(self, data): return self.softmax(data) input_info = [([1, 3, 10, 10], "float32")] - for via_relax in [True, False]: - verify_model(Softmax(), input_info, via_relax) + verify_model(Softmax(), input_info) def test_binary(): @@ -448,9 +431,8 @@ class Add2(Module): def forward(self, lhs): return lhs + 1.0 - for via_relax in [True, False]: - verify_model(Add1(), input_info1, via_relax) - verify_model(Add2(), input_info2, via_relax) + verify_model(Add1(), input_info1) + verify_model(Add2(), input_info2) # Sub class Sub1(Module): @@ -461,9 +443,8 @@ class Sub2(Module): def forward(self, lhs): return lhs - 1.0 - for via_relax in [True, False]: - verify_model(Sub1(), input_info1, via_relax) - verify_model(Sub2(), input_info2, via_relax) + verify_model(Sub1(), input_info1) + verify_model(Sub2(), input_info2) # Mul class Mul1(Module): @@ -474,9 +455,8 @@ class Mul2(Module): def forward(self, lhs): return lhs * 1.0 - for via_relax in [True, False]: - verify_model(Mul1(), input_info1, via_relax) - verify_model(Mul2(), input_info2, via_relax) + verify_model(Mul1(), input_info1) + verify_model(Mul2(), input_info2) # True div class TrueDiv1(Module): @@ -487,9 +467,8 @@ class TrueDiv2(Module): def forward(self, lhs): return lhs / 1.0 - for via_relax in [True, False]: - verify_model(TrueDiv1(), input_info1, via_relax) - verify_model(TrueDiv2(), input_info2, via_relax) + verify_model(TrueDiv1(), input_info1) + verify_model(TrueDiv2(), input_info2) # Floor div class FloorDiv1(Module): @@ -500,9 +479,8 @@ class FloorDiv2(Module): def forward(self, lhs): return lhs // 1.0 - for via_relax in [True, False]: - verify_model(FloorDiv1(), input_info1, via_relax) - verify_model(FloorDiv2(), input_info2, via_relax) + verify_model(FloorDiv1(), input_info1) + verify_model(FloorDiv2(), input_info2) # Power class Power1(Module): @@ -513,9 +491,8 @@ class Power2(Module): def forward(self, lhs): return lhs**1.0 - for via_relax in [True, False]: - verify_model(Power1(), input_info1, via_relax) - verify_model(Power2(), input_info2, via_relax) + verify_model(Power1(), input_info1) + verify_model(Power2(), input_info2) # LT class LT1(Module): @@ -526,9 +503,8 @@ class LT2(Module): def forward(self, lhs): return lhs < 1.0 - for via_relax in [True, False]: - verify_model(LT1(), input_info1, via_relax) - verify_model(LT2(), input_info2, via_relax) + verify_model(LT1(), input_info1) + verify_model(LT2(), input_info2) def test_size(): @@ -554,9 +530,8 @@ def forward(self, data): return data.squeeze() input_info = [([3, 1, 4, 1], "float32")] - for via_relax in [True, False]: - verify_model(Squeeze1(), input_info, via_relax) - verify_model(Squeeze2(), input_info, via_relax) + verify_model(Squeeze1(), input_info) + verify_model(Squeeze2(), input_info) def test_unsqueeze(): @@ -571,9 +546,8 @@ def forward(self, data): return data.unsqueeze(-1) input_info = [([1, 3, 10, 10], "float32")] - for via_relax in [True, False]: - verify_model(Unsqueeze1(), input_info, via_relax) - verify_model(Unsqueeze2(), input_info, via_relax) + verify_model(Unsqueeze1(), input_info) + verify_model(Unsqueeze2(), input_info) def test_getattr(): @@ -599,9 +573,8 @@ class Slice2(Module): def forward(self, x): return x[:, None, None, :, None] - for via_relax in [True, False]: - verify_model(Slice1(), [([1, 3, 10, 10], "float32")], via_relax) - verify_model(Slice2(), [([8, 16], "float32")], via_relax) + verify_model(Slice1(), [([1, 3, 10, 10], "float32")]) + 
verify_model(Slice2(), [([8, 16], "float32")]) def test_unary(): @@ -614,48 +587,42 @@ class Sin(Module): def forward(self, data): return torch.sin(data) - for via_relax in [True, False]: - verify_model(Sin(), input_info, via_relax) + verify_model(Sin(), input_info) # cos class Cos(Module): def forward(self, data): return torch.cos(data) - for via_relax in [True, False]: - verify_model(Cos(), input_info, via_relax) + verify_model(Cos(), input_info) # exp class Exp(Module): def forward(self, data): return torch.exp(data) - for via_relax in [True, False]: - verify_model(Exp(), input_info, via_relax) + verify_model(Exp(), input_info) # sqrt class Sqrt(Module): def forward(self, data): return torch.sqrt(data) - for via_relax in [True, False]: - verify_model(Sqrt(), input_info, via_relax) + verify_model(Sqrt(), input_info) # sigmoid class Sigmoid(Module): def forward(self, data): return torch.sigmoid(data) - for via_relax in [True, False]: - verify_model(Sigmoid(), input_info, via_relax) + verify_model(Sigmoid(), input_info) # round class Round(Module): def forward(self, data): return torch.round(data) - for via_relax in [True, False]: - verify_model(Round(), input_info, via_relax) + verify_model(Round(), input_info) def test_gelu(): @@ -666,8 +633,7 @@ def forward(self, data): return torch.nn.functional.gelu(data) input_info = [([1, 3, 10, 10], "float32")] - for via_relax in [True, False]: - verify_model(Gelu(), input_info, via_relax) + verify_model(Gelu(), input_info) def test_tanh(): @@ -678,8 +644,7 @@ def forward(self, data): return torch.tanh(data) input_info = [([1, 3, 10, 10], "float32")] - for via_relax in [True, False]: - verify_model(Tanh(), input_info, via_relax) + verify_model(Tanh(), input_info) def test_clamp(): @@ -690,8 +655,7 @@ def forward(self, data): return torch.clamp(data, min=0.1, max=0.5) input_info = [([1, 3, 10, 10], "float32")] - for via_relax in [True, False]: - verify_model(Clamp(), input_info, via_relax) + verify_model(Clamp(), input_info) def test_interpolate(): @@ -702,8 +666,7 @@ def forward(self, data): return torch.nn.functional.interpolate(data, (5, 5)) input_info = [([1, 3, 10, 10], "float32")] - for via_relax in [True, False]: - verify_model(Interpolate(), input_info, via_relax) + verify_model(Interpolate(), input_info) def test_addmm(): @@ -718,8 +681,7 @@ def forward(self, x_1, x_2, x_3): ([10, 10], "float32"), ([10, 10], "float32"), ] - for via_relax in [True, False]: - verify_model(Addmm(), input_info, via_relax) + verify_model(Addmm(), input_info) def test_split(): @@ -734,9 +696,8 @@ def forward(self, data): return torch.split(data, [1, 2], dim=1) input_info = [([1, 3, 10, 10], "float32")] - for via_relax in [True, False]: - verify_model(Split1(), input_info, via_relax) - verify_model(Split2(), input_info, via_relax) + verify_model(Split1(), input_info) + verify_model(Split2(), input_info) def test_unbind(): @@ -751,9 +712,8 @@ def forward(self, data): return torch.unbind(data, dim=1) input_info = [([3, 3, 10, 10], "float32")] - for via_relax in [True, False]: - verify_model(Unbind1(), input_info, via_relax) - verify_model(Unbind2(), input_info, via_relax) + verify_model(Unbind1(), input_info) + verify_model(Unbind2(), input_info) def test_cumsum(): @@ -764,8 +724,7 @@ def forward(self, data): return torch.cumsum(data, dim=1, dtype=torch.int32) input_info = [([1, 2, 3, 4], "float32")] - for via_relax in [True, False]: - verify_model(Cumsum(), input_info, via_relax) + verify_model(Cumsum(), input_info) def test_chunk(): @@ -776,8 +735,7 @@ def 
forward(self, data): return torch.chunk(data, 3, dim=1) input_info = [([1, 3, 10, 10], "float32")] - for via_relax in [True, False]: - verify_model(Chunk(), input_info, via_relax) + verify_model(Chunk(), input_info) def test_inplace_fill(): @@ -788,8 +746,7 @@ def forward(self, data): data.fill_(1.5) return data - for via_relax in [True, False]: - verify_model(InplaceFill(), [([10, 10], "float32")], via_relax) + verify_model(InplaceFill(), [([10, 10], "float32")]) def test_arange(): @@ -816,9 +773,8 @@ def forward(self, data): return data input_info = [([10, 10], "float32")] - for via_relax in [True, False]: - verify_model(Tril(), input_info, via_relax) - verify_model(InplaceTril(), input_info, via_relax) + verify_model(Tril(), input_info) + verify_model(InplaceTril(), input_info) def test_triu(): @@ -834,9 +790,8 @@ def forward(self, data): return data input_info = [([10, 10], "float32")] - for via_relax in [True, False]: - verify_model(Triu(), input_info, via_relax) - verify_model(InplaceTriu(), input_info, via_relax) + verify_model(Triu(), input_info) + verify_model(InplaceTriu(), input_info) def test_new_ones(): @@ -847,8 +802,7 @@ def forward(self, x): return x.new_ones(1, 2, 3) input_info = [([1, 2, 3], "float32")] - for via_relax in [True, False]: - verify_model(NewOnes(), input_info, via_relax) + verify_model(NewOnes(), input_info) def test_expand(): @@ -863,9 +817,8 @@ def forward(self, x): return x.expand(4, -1, -1, 4) input_info = [([1, 2, 3, 4], "float32")] - for via_relax in [True, False]: - verify_model(Expand1(), input_info, via_relax) - verify_model(Expand2(), input_info, via_relax) + verify_model(Expand1(), input_info) + verify_model(Expand2(), input_info) def test_reduce(): @@ -887,10 +840,9 @@ def forward(self, x): return torch.min(x) input_info = [([1, 2, 3, 4], "float32")] - for via_relax in [True, False]: - verify_model(Sum(), input_info, via_relax) - verify_model(Max(), input_info, False) - verify_model(Min(), input_info, False) + verify_model(Sum(), input_info) + verify_model(Max(), input_info) + verify_model(Min(), input_info) def test_datatype(): @@ -903,24 +855,21 @@ class ToFloat(Module): def forward(self, x): return x.float() - for via_relax in [True, False]: - verify_model(ToFloat(), input_info, via_relax) + verify_model(ToFloat(), input_info) # half class ToHalf(Module): def forward(self, x): return x.half() - for via_relax in [True, False]: - verify_model(ToHalf(), input_info, via_relax) + verify_model(ToHalf(), input_info) # type class Type(Module): def forward(self, x): return x.type(torch.float32) - for via_relax in [True, False]: - verify_model(Type(), input_info, via_relax) + verify_model(Type(), input_info) def test_permute(): @@ -931,8 +880,7 @@ def forward(self, x): return x.permute(0, 3, 2, 1) input_info = [([1, 2, 3, 4], "float32")] - for via_relax in [True, False]: - verify_model(Permute(), input_info, via_relax) + verify_model(Permute(), input_info) def test_reshape(): @@ -943,8 +891,7 @@ def forward(self, x): return x.reshape(2, 12) input_info = [([1, 2, 3, 4], "float32")] - for via_relax in [True, False]: - verify_model(Reshape(), input_info, via_relax) + verify_model(Reshape(), input_info) def test_transpose(): @@ -955,8 +902,7 @@ def forward(self, x): return x.transpose(1, 3) input_info = [([1, 2, 3, 4], "float32")] - for via_relax in [True, False]: - verify_model(Transpose(), input_info, via_relax) + verify_model(Transpose(), input_info) def test_view(): @@ -967,8 +913,7 @@ def forward(self, x): return x.view(2, 12) input_info = [([1, 2, 3, 
4], "float32")] - for via_relax in [True, False]: - verify_model(View(), input_info, via_relax) + verify_model(View(), input_info) def test_keep_params(): @@ -982,8 +927,7 @@ def __init__(self): def forward(self, data): return self.conv(data) - for via_relax in [True, False]: - verify_model(Conv2D1(), [([1, 3, 10, 10], "float32")], via_relax) + verify_model(Conv2D1(), [([1, 3, 10, 10], "float32")]) def test_unwrap_unit_return_tuple(): @@ -993,8 +937,7 @@ class Identity(Module): def forward(self, x): return (x,) - for via_relax in [True, False]: - verify_model(Identity(), [([256, 256], "float32")], via_relax) + verify_model(Identity(), [([256, 256], "float32")]) def test_no_bind_return_tuple(): @@ -1005,8 +948,7 @@ def forward(self, x, y): return (x, y) input_info = [([256, 256], "float32"), ([256, 256], "float32")] - for via_relax in [True, False]: - verify_model(Identity(), input_info, via_relax) + verify_model(Identity(), input_info) def test_argmax(): @@ -1020,9 +962,8 @@ class Argmax2(Module): def forward(self, data): return torch.argmax(data, dim=-1, keepdim=True) - for via_relax in [True, False]: - verify_model(Argmax1(), [([256, 256], "float32")], via_relax) - verify_model(Argmax2(), [([256, 256], "float32")], via_relax) + verify_model(Argmax1(), [([256, 256], "float32")]) + verify_model(Argmax2(), [([256, 256], "float32")]) def test_argmin(): @@ -1051,9 +992,8 @@ class To2(Module): def forward(self, data): return data.to("cpu") - for via_relax in [True, False]: - verify_model(To1(), [([256, 256], "float32")], via_relax) - verify_model(To2(), [([256, 256], "float32")], via_relax) + verify_model(To1(), [([256, 256], "float32")]) + verify_model(To2(), [([256, 256], "float32")]) def test_mean(): @@ -1067,9 +1007,8 @@ class MeanKeepDim(Module): def forward(self, data): return data.mean(-1, keepdim=True) - for via_relax in [True, False]: - verify_model(Mean(), [([256, 256], "float32")], via_relax) - verify_model(MeanKeepDim(), [([256, 256], "float32")], via_relax) + verify_model(Mean(), [([256, 256], "float32")]) + verify_model(MeanKeepDim(), [([256, 256], "float32")]) def test_rsqrt(): @@ -1079,8 +1018,7 @@ class Rsqrt(Module): def forward(self, data): return torch.rsqrt(data) - for via_relax in [True, False]: - verify_model(Rsqrt(), [([256, 256], "float32")], via_relax) + verify_model(Rsqrt(), [([256, 256], "float32")]) def test_neg(): @@ -1090,8 +1028,7 @@ class Neg(Module): def forward(self, data): return -data - for via_relax in [True, False]: - verify_model(Neg(), [([256, 256], "float32")], via_relax) + verify_model(Neg(), [([256, 256], "float32")]) def test_max(): @@ -1101,8 +1038,7 @@ class Max(Module): def forward(self, x, y): return torch.max(x, y) - for via_relax in [True, False]: - verify_model(Max(), [([256, 256], "float32"), ([256, 256], "float32")], via_relax) + verify_model(Max(), [([256, 256], "float32"), ([256, 256], "float32")]) def test_cat(): @@ -1123,9 +1059,8 @@ def forward(self, data): ([1, 3, 10, 10], "float32"), ([1, 3, 10, 10], "float32"), ] - for via_relax in [True, False]: - verify_model(Cat1(), input_info, via_relax) - verify_model(Cat2(), [([1, 3, 10, 10], "float32")], via_relax) + verify_model(Cat1(), input_info) + verify_model(Cat2(), [([1, 3, 10, 10], "float32")]) def test_stack(): @@ -1146,9 +1081,8 @@ def forward(self, data): ([1, 3, 10, 10], "float32"), ([1, 3, 10, 10], "float32"), ] - for via_relax in [True, False]: - verify_model(Stack1(), input_info, via_relax) - verify_model(Stack2(), [([1, 3, 10, 10], "float32")], via_relax) + 
verify_model(Stack1(), input_info) + verify_model(Stack2(), [([1, 3, 10, 10], "float32")]) def test_scatter(): @@ -1166,11 +1100,8 @@ class Scatter2(Module): def forward(self, data, index, src): return data.scatter(0, index, src) - for via_relax in [True, False]: - verify_model(Scatter1(), [([20, 20], "float32"), ([2, 5], "float32")], via_relax) - verify_model( - Scatter2(), [([20, 20], "float32"), ([2, 5], "int64"), ([2, 5], "float32")], via_relax - ) + verify_model(Scatter1(), [([20, 20], "float32"), ([2, 5], "float32")]) + verify_model(Scatter2(), [([20, 20], "float32"), ([2, 5], "int64"), ([2, 5], "float32")]) def test_masked_scatter(): From 3fac2232ae4cd069436f64380182db9940ba06bd Mon Sep 17 00:00:00 2001 From: "meng.tong" Date: Wed, 26 Feb 2025 20:18:16 +0800 Subject: [PATCH 2/5] remove relay tests --- .../contrib/test_msc/test_graph_build.py | 25 --------------- .../contrib/test_msc/test_translate_relax.py | 16 ---------- .../contrib/test_msc/test_translate_torch.py | 32 ++----------------- tests/scripts/task_config_build_cpu.sh | 3 +- tests/scripts/task_config_build_gpu.sh | 3 +- tests/scripts/unity/task_python_relax.sh | 2 +- 6 files changed, 5 insertions(+), 76 deletions(-) diff --git a/tests/python/contrib/test_msc/test_graph_build.py b/tests/python/contrib/test_msc/test_graph_build.py index 40f61eaf8291..5396b5e106a6 100644 --- a/tests/python/contrib/test_msc/test_graph_build.py +++ b/tests/python/contrib/test_msc/test_graph_build.py @@ -2557,31 +2557,6 @@ def forward(self, data, mask, src): ) -def test_put(): - """test graph builder for index_put""" - - class IndexPut(Module): - def __init__(self): - super().__init__() - self.index = msc_utils.random_data([(5), "int64"], MSCFramework.TORCH, max_val=5) - - def forward(self, data, src): - data[self.index] = src - return data - - expected = { - "inputs": [ - {"name": "input0", "shape": [10, 20], "dtype": "float32", "layout": ""}, - {"name": "input1", "shape": [5, 20], "dtype": "float32", "layout": ""}, - ], - "outputs": [{"name": "scatter_nd", "shape": [10, 20], "dtype": "float32", "layout": ""}], - "nodes": {"total": 4, "input": 2, "constant": 1, "scatter_nd": 1}, - } - - input_info = [([10, 20], "float32"), ([5, 20], "float32")] - verify_model(IndexPut(), input_info, expected) - - @pytest.mark.parametrize("dynamic", [True, False]) def test_attention(dynamic): """test graph builder for attention""" diff --git a/tests/python/contrib/test_msc/test_translate_relax.py b/tests/python/contrib/test_msc/test_translate_relax.py index 7ed18574e814..2c1af75f9a33 100644 --- a/tests/python/contrib/test_msc/test_translate_relax.py +++ b/tests/python/contrib/test_msc/test_translate_relax.py @@ -1216,22 +1216,6 @@ def forward(self, data, src): verify_model(MaskedScatter2(), [([2, 5], "float32"), ([3, 5], "float32")]) -def test_put(): - """test relax translator for index_put""" - - class IndexPut(Module): - def __init__(self): - super().__init__() - self.index = msc_utils.random_data([(5), "int64"], MSCFramework.TORCH, max_val=5) - - def forward(self, data, src): - data[self.index] = src - return data - - input_info = [([10, 20], "float32"), ([5, 20], "float32")] - verify_model(IndexPut(), input_info) - - def test_attention(): """test relax translator for attention""" diff --git a/tests/python/contrib/test_msc/test_translate_torch.py b/tests/python/contrib/test_msc/test_translate_torch.py index 22a959d2975d..b766081ea58b 100644 --- a/tests/python/contrib/test_msc/test_translate_torch.py +++ b/tests/python/contrib/test_msc/test_translate_torch.py 
@@ -829,20 +829,8 @@ class Sum(Module): def forward(self, x): return torch.sum(x, (2, 1)) - # max - class Max(Module): - def forward(self, x): - return torch.max(x) - - # min - class Min(Module): - def forward(self, x): - return torch.min(x) - input_info = [([1, 2, 3, 4], "float32")] verify_model(Sum(), input_info) - verify_model(Max(), input_info) - verify_model(Min(), input_info) def test_datatype(): @@ -1123,24 +1111,8 @@ def __init__(self): def forward(self, data, src): return data.masked_scatter(self.mask, src) - verify_model(MaskedScatter1(), [([5], "float32"), ([10], "float32")], True) - verify_model(MaskedScatter2(), [([2, 5], "float32"), ([3, 5], "float32")], True) - - -def test_put(): - """test torch translator for index_put""" - - class IndexPut(Module): - def __init__(self): - super().__init__() - self.index = msc_utils.random_data([(5), "int64"], MSCFramework.TORCH, max_val=5) - - def forward(self, data, src): - data[self.index] = src - return data - - input_info = [([10, 20], "float32"), ([5, 20], "float32")] - verify_model(IndexPut(), input_info, False) + verify_model(MaskedScatter1(), [([5], "float32"), ([10], "float32")]) + verify_model(MaskedScatter2(), [([2, 5], "float32"), ([3, 5], "float32")]) def test_attention(): diff --git a/tests/scripts/task_config_build_cpu.sh b/tests/scripts/task_config_build_cpu.sh index 20bf6100f4a1..cd84f5ded46a 100755 --- a/tests/scripts/task_config_build_cpu.sh +++ b/tests/scripts/task_config_build_cpu.sh @@ -46,5 +46,4 @@ echo set\(BACKTRACE_ON_SEGFAULT ON\) >> config.cmake echo set\(USE_CCACHE OFF\) >> config.cmake echo set\(USE_UMA ON\) >> config.cmake echo set\(SUMMARIZE ON\) >> config.cmake -# Temporary disable MSC -# echo set\(USE_MSC ON\) >> config.cmake +echo set\(USE_MSC ON\) >> config.cmake diff --git a/tests/scripts/task_config_build_gpu.sh b/tests/scripts/task_config_build_gpu.sh index 476886782620..74bb702a8b08 100755 --- a/tests/scripts/task_config_build_gpu.sh +++ b/tests/scripts/task_config_build_gpu.sh @@ -43,7 +43,6 @@ echo set\(SUMMARIZE ON\) >> config.cmake echo set\(HIDE_PRIVATE_SYMBOLS ON\) >> config.cmake echo set\(USE_PIPELINE_EXECUTOR ON\) >> config.cmake echo set\(USE_CUTLASS ON\) >> config.cmake -# Temporary disable MSC -# echo set\(USE_MSC ON\) >> config.cmake +echo set\(USE_MSC ON\) >> config.cmake echo set\(CMAKE_CUDA_ARCHITECTURES 75\) >> config.cmake echo set\(USE_CLML ON\) >> config.cmake diff --git a/tests/scripts/unity/task_python_relax.sh b/tests/scripts/unity/task_python_relax.sh index 5eb2a9e4201e..5a72254924e1 100755 --- a/tests/scripts/unity/task_python_relax.sh +++ b/tests/scripts/unity/task_python_relax.sh @@ -38,7 +38,7 @@ TVM_TEST_TARGETS="${TVM_RELAY_TEST_TARGETS:-llvm}" pytest tests/python/dlight # python3 ./apps/relax_examples/resnet.py # Test for MSC -# pytest tests/python/contrib/test_msc +pytest tests/python/contrib/test_msc # Test for OpenCLML pytest tests/python/relax/backend/clml/ From cf4db1011106ff8a0bf05c5655cc245cfd0b6dff Mon Sep 17 00:00:00 2001 From: "meng.tong" Date: Sat, 1 Mar 2025 07:23:51 +0800 Subject: [PATCH 3/5] remove useless --- .../python/contrib/test_msc/test_transform.py | 33 ------------------- 1 file changed, 33 deletions(-) diff --git a/tests/python/contrib/test_msc/test_transform.py b/tests/python/contrib/test_msc/test_transform.py index 0983be958946..0d9c29837175 100644 --- a/tests/python/contrib/test_msc/test_transform.py +++ b/tests/python/contrib/test_msc/test_transform.py @@ -21,10 +21,6 @@ from tvm.relax.frontend.torch import from_fx from tvm.relax import 
PyExprVisitor -from tvm.relay import testing -from tvm.relay.expr_functor import ExprVisitor -from tvm.relay.build_module import bind_params_by_name - from tvm.contrib.msc.core import transform as msc_transform from tvm.contrib.msc.core import utils as msc_utils @@ -73,35 +69,6 @@ def visit_constant_(self, op) -> None: RelaxLayoutChecker().check(mod) -def test_relay_name(): - """Test SetExprName for relay""" - - class RelayNameChecker(ExprVisitor): - """Check if name as span attribute is setted.""" - - def check(self, expr): - self._missing_exprs = [] - super().visit(expr) - assert len(self._missing_exprs) == 0, "Missing {} names".format( - len(self._missing_exprs) - ) - - def visit_constant(self, expr): - super().visit_constant(expr) - if not msc_utils.get_expr_name(expr): - self._missing_exprs.append(expr) - - def visit_call(self, expr): - super().visit_call(expr) - if not msc_utils.get_expr_name(expr): - self._missing_exprs.append(expr) - - mod, params = testing.resnet.get_workload(num_layers=50, batch_size=1, dtype="float32") - mod["main"] = bind_params_by_name(mod["main"], params) - mod = msc_transform.SetExprName(as_relax=False)(mod) - RelayNameChecker().check(mod["main"]) - - def test_relax(): """Test SetExprName for relax""" From 6376888ed61100f35f633680a4f65b4872c07c35 Mon Sep 17 00:00:00 2001 From: "meng.tong" Date: Sun, 2 Mar 2025 07:01:55 +0800 Subject: [PATCH 4/5] update torch and test --- .../msc/framework/torch/codegen/codegen.py | 2 +- .../msc/framework/torch/runtime/runner.py | 7 +-- tests/python/contrib/test_msc/test_runner.py | 50 ------------------- 3 files changed, 2 insertions(+), 57 deletions(-) diff --git a/python/tvm/contrib/msc/framework/torch/codegen/codegen.py b/python/tvm/contrib/msc/framework/torch/codegen/codegen.py index d4aeabb10a1b..bb3082e5d8f8 100644 --- a/python/tvm/contrib/msc/framework/torch/codegen/codegen.py +++ b/python/tvm/contrib/msc/framework/torch/codegen/codegen.py @@ -71,7 +71,7 @@ def _save_weights(folder: msc_utils.MSCDirectory): def _bind_weights(model: torch.nn.Module, folder: msc_utils.MSCDirectory) -> torch.nn.Module: if weights: - state_dict = torch.load(folder.relpath(graph.name + ".pth")) + state_dict = torch.load(folder.relpath(graph.name + ".pth"), weights_only=False) model.load_state_dict(state_dict) return model diff --git a/python/tvm/contrib/msc/framework/torch/runtime/runner.py b/python/tvm/contrib/msc/framework/torch/runtime/runner.py index 27773cecdc6d..a4d37d08f521 100644 --- a/python/tvm/contrib/msc/framework/torch/runtime/runner.py +++ b/python/tvm/contrib/msc/framework/torch/runtime/runner.py @@ -319,12 +319,7 @@ def update_config(cls, stage: str, config: dict, model: Any = None) -> dict: config["parse"]["parser"] = from_torch parse_config = config["parse"].get("parse_config", {}) parse_config.update( - { - "input_info": [ - [i[1], "float" if len(i) < 2 else i[2]] for i in config["inputs"] - ], - "input_names": [i[0] for i in config["inputs"]], - } + {"input_info": [[i[1], "float" if len(i) < 2 else i[2]] for i in config["inputs"]]} ) config["parse"]["parse_config"] = parse_config return config diff --git a/tests/python/contrib/test_msc/test_runner.py b/tests/python/contrib/test_msc/test_runner.py index c75974051d4b..14c872beddc1 100644 --- a/tests/python/contrib/test_msc/test_runner.py +++ b/tests/python/contrib/test_msc/test_runner.py @@ -22,15 +22,12 @@ import torch from torch import fx -from tvm.contrib.msc.framework.tensorflow import tf_v1 import tvm.testing from tvm.relax.frontend.torch import from_fx from 
tvm.contrib.msc.framework.tvm.runtime import TVMRunner from tvm.contrib.msc.framework.torch.runtime import TorchRunner from tvm.contrib.msc.framework.tensorrt.runtime import TensorRTRunner -from tvm.contrib.msc.framework.tensorflow.frontend import from_tensorflow -from tvm.contrib.msc.framework.tensorflow.runtime import TensorflowRunner from tvm.contrib.msc.core import utils as msc_utils requires_tensorrt = pytest.mark.skipif( @@ -57,27 +54,6 @@ def _get_torch_model(name, training=False): return None -def _get_tf_graph(): - """Get tensorflow graphdef""" - - # pylint: disable=import-outside-toplevel - try: - import tvm.relay.testing.tf as tf_testing - - tf_graph = tf_v1.Graph() - with tf_graph.as_default(): - graph_def = tf_testing.get_workload( - "https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.4_224.tgz", - "mobilenet_v2_1.4_224_frozen.pb", - ) - # Call the utility to import the graph definition into default graph. - graph_def = tf_testing.ProcessGraphDefParam(graph_def) - return tf_graph, graph_def - except: # pylint: disable=bare-except - print("please install tensorflow package") - return None, None - - def _test_from_torch(runner_cls, device, training=False, atol=1e-1, rtol=1e-1): """Test runner from torch model""" @@ -142,31 +118,5 @@ def test_tensorrt_runner(): _test_from_torch(TensorRTRunner, "cuda", atol=1e-1, rtol=1e-1) -@pytest.mark.skip(reason="Failed due to tf and tflite upgrade.") -def test_tensorflow_runner(): - """Test runner from tf graph""" - - tf_graph, graph_def = _get_tf_graph() - if tf_graph and graph_def: - path = "test_runner_tf" - workspace = msc_utils.set_workspace(msc_utils.msc_dir(path, keep_history=False)) - log_path = workspace.relpath("MSC_LOG", keep_history=False) - msc_utils.set_global_logger("critical", log_path) - data = np.random.uniform(size=(1, 224, 224, 3)).astype("float32") - out_name = "MobilenetV2/Predictions/Reshape_1:0" - # get golden - with tf_v1.Session(graph=tf_graph) as sess: - golden = sess.run([out_name], {"input:0": data}) - # get outputs - shape_dict = {"input": data.shape} - mod, _ = from_tensorflow(graph_def, shape_dict, [out_name], as_msc=False) - runner = TensorflowRunner(mod) - runner.build() - outputs = runner.run([data], ret_type="list") - workspace.destory() - for gol_r, out_r in zip(golden, outputs): - tvm.testing.assert_allclose(gol_r, msc_utils.cast_array(out_r), atol=1e-3, rtol=1e-3) - - if __name__ == "__main__": tvm.testing.main() From 666694e1961f4d0b65885f60706f4d36fa8974fa Mon Sep 17 00:00:00 2001 From: "meng.tong" Date: Sun, 2 Mar 2025 21:16:13 +0800 Subject: [PATCH 5/5] remove dynamic test --- .../python/contrib/test_msc/test_pipeline.py | 81 +------------------ 1 file changed, 3 insertions(+), 78 deletions(-) diff --git a/tests/python/contrib/test_msc/test_pipeline.py b/tests/python/contrib/test_msc/test_pipeline.py index b55667004c67..b892b914a96e 100644 --- a/tests/python/contrib/test_msc/test_pipeline.py +++ b/tests/python/contrib/test_msc/test_pipeline.py @@ -73,28 +73,6 @@ def _get_torch_model(name, training=False): return None -def _get_tf_graph(): - """Get graph from tensorflow""" - - # pylint: disable=import-outside-toplevel - try: - from tvm.contrib.msc.framework.tensorflow import tf_v1 - import tvm.relay.testing.tf as tf_testing - - tf_graph = tf_v1.Graph() - with tf_graph.as_default(): - graph_def = tf_testing.get_workload( - "https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.4_224.tgz", - "mobilenet_v2_1.4_224_frozen.pb", - ) - # Call the utility to import the 
graph definition into default graph. - graph_def = tf_testing.ProcessGraphDefParam(graph_def) - return graph_def - except: # pylint: disable=bare-except - print("please install tensorflow package") - return None - - def _check_pipeline(pipeline, expected_info, dynamic=False): """Check the pipeline results""" @@ -136,25 +114,7 @@ def _test_from_torch( _check_pipeline(pipeline, expected_info, dynamic) -def _test_from_tf(compile_type, expected_info, atol=1e-2, rtol=1e-2): - graphdef = _get_tf_graph() - if graphdef: - config = _get_config( - MSCFramework.TENSORFLOW, - compile_type, - inputs=[["input", [1, 224, 224, 3], "float32"]], - outputs=["MobilenetV2/Predictions/Reshape_1:0"], - atol=atol, - rtol=rtol, - ) - config["compile"]["profile"]["check"]["err_rate"] = -1 - manager = MSCManager(graphdef, config) - manager.run_pipe() - _check_pipeline(manager, expected_info) - - -@pytest.mark.skip(reason="Failed due to tf and tflite upgrade.") -@pytest.mark.parametrize("dynamic", [False, True]) +@pytest.mark.parametrize("dynamic", [False]) def test_tvm_pipeline(dynamic): """Test pipeline for tvm""" @@ -207,10 +167,9 @@ def test_tvm_pipeline(dynamic): "nn.softmax": 1, }, } - _test_from_tf(MSCFramework.TVM, model_info) -@pytest.mark.parametrize("dynamic", [False, True]) +@pytest.mark.parametrize("dynamic", [False]) def test_torch_pipeline(dynamic): """Test pipeline for torch""" @@ -236,42 +195,8 @@ def test_torch_pipeline(dynamic): _test_from_torch(MSCFramework.TORCH, model_info, training=False, dynamic=dynamic) -@pytest.mark.skip(reason="Failed due to tf and tflite upgrade.") -def test_tensorflow_pipeline(): - """Test manager for tensorflow""" - - model_info = { - "inputs": [ - {"name": "input", "shape": [1, 224, 224, 3], "dtype": "float32", "layout": "NHWC"} - ], - "outputs": [ - { - "name": "MobilenetV2/Predictions/Reshape_1:0", - "shape": [1, 1001], - "dtype": "float32", - "layout": "NC", - } - ], - "nodes": { - "total": 138, - "input": 1, - "msc.conv2d_bias": 36, - "clip": 35, - "nn.conv2d": 17, - "nn.batch_norm": 17, - "get_item": 17, - "add": 10, - "nn.avg_pool2d": 1, - "squeeze": 1, - "reshape": 2, - "nn.softmax": 1, - }, - } - _test_from_tf(MSCFramework.TENSORFLOW, model_info) - - @requires_tensorrt -@pytest.mark.parametrize("dynamic", [False, True]) +@pytest.mark.parametrize("dynamic", [False]) def test_tensorrt_pipeline(dynamic): """Test pipeline for tensorrt"""
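
For context, a minimal sketch of exercising the configuration these patches re-enable; the two commands are taken verbatim from the patched scripts above (task_config_build_cpu.sh / task_config_build_gpu.sh and unity/task_python_relax.sh), while the surrounding build and environment setup steps are assumed:

    # Append the MSC flag to the local CMake config, as the patched build
    # scripts now do, then run the re-activated MSC test suite.
    echo set\(USE_MSC ON\) >> config.cmake
    pytest tests/python/contrib/test_msc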