Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 0 additions & 43 deletions python/tvm/contrib/msc/core/codegen/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
from tvm.relax import PyExprVisitor
from tvm.contrib.msc.core import transform as msc_transform
from tvm.contrib.msc.core.ir import MSCGraph, MSCTensor
from tvm.contrib.msc.core.frontend import from_relay
from tvm.contrib.msc.core import utils as msc_utils


Expand Down Expand Up @@ -216,45 +215,3 @@ def _post_proc(mod: tvm.IRModule, folder: msc_utils.MSCDirectory) -> tvm.IRModul
if plugin:
model_args = model_args + [plugin]
return codegen.load(model_args, pre_load=_save_weights, post_load=_post_proc)


def relay_to_relax(
relay_mod: tvm.IRModule,
params: Optional[Dict[str, tvm.nd.array]] = None,
trans_config: Optional[Dict[str, str]] = None,
build_config: Optional[Dict[str, str]] = None,
opt_config: Optional[Dict[str, str]] = None,
build_folder: msc_utils.MSCDirectory = None,
) -> tvm.IRModule:
"""Change relay IRModule to relax MSCGraph.

Parameters
----------
relay_mod: IRModule
The IRModule of relay.
params: dict of <string:tvm.ndarray>
The parameters of the IRModule.
trans_config: dict
The config for transform IRModule.
build_config: dict
The config for build MSCGraph.
opt_config: dict
The config for optimize the relay before translate.
build_folder: MSCDirectory
The folder for saving scripts and datas.

Returns
-------
relax_mod: IRModule
The IRModule of relax.
"""

graph, weights = from_relay(
relay_mod,
params,
trans_config=trans_config,
build_config=build_config,
opt_config=opt_config,
)

return to_relax(graph, weights, codegen_config={"from_relay": True}, build_folder=build_folder)
114 changes: 0 additions & 114 deletions python/tvm/contrib/msc/core/frontend/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,6 @@
from tvm.relax.transform import BindParams
from tvm.relax import PyExprVisitor
from tvm.relax.backend.pattern_registry import get_patterns_with_prefix
from tvm.relay.expr_functor import ExprVisitor
from tvm.relay.build_module import bind_params_by_name
from tvm.relay import dataflow_pattern as relay_pattern
from tvm.contrib.msc.core import transform as msc_transform
from tvm.contrib.msc.core import _ffi_api
from tvm.contrib.msc.core import utils as msc_utils
Expand Down Expand Up @@ -170,117 +167,6 @@ def from_relax(
return graph, normalize_weights(t_weights, graph)


def get_relay_patterns(
mod: tvm.IRModule,
entry_name: str = "main",
) -> List[Tuple[str, relay_pattern.DFPattern, callable]]:
"""Filter relay patterns based on mod.

Parameters
----------
mod: IRModule
The IRModule of relay.
entry_name: str
The entry name.

Returns
-------
patterns: list
The useful patterns for relay
"""

class OpExtractor(ExprVisitor):
"""Extract ops from expr."""

def extract(self, expr):
self._optypes = set()
super().visit(expr)
return self._optypes

def visit_call(self, expr):
super().visit_call(expr)
if isinstance(expr.op, tvm.ir.Op):
self._optypes.add(expr.op.name)

op_names = OpExtractor().extract(mod[entry_name])
skip_tags, patterns = set(), list(tvm.relay.op.contrib.get_pattern_table("msc"))
if "nn.conv1d" not in op_names or "add" not in op_names:
skip_tags.add("msc.conv1d_bias")
if "nn.conv2d" not in op_names or "add" not in op_names:
skip_tags.add("msc.conv2d_bias")
if "nn.batch_matmul" not in op_names or "add" not in op_names:
skip_tags.add("msc.linear_bias")
if "nn.batch_matmul" not in op_names:
skip_tags |= set(p[0] for p in patterns if p[0].startswith("msc.linear"))
if "nn.dense" not in op_names:
skip_tags |= set(p[0] for p in patterns if p[0].startswith("msc.matmul"))
if "take" not in op_names:
skip_tags |= set(p[0] for p in patterns if p[0].startswith("msc.embedding"))
if "erf" not in op_names:
skip_tags |= set(p[0] for p in patterns if p[0].startswith("msc.gelu"))
valid_patterns = [p for p in patterns if p[0] not in skip_tags]
return valid_patterns


def from_relay(
mod: tvm.IRModule,
params: Optional[Dict[str, tvm.nd.array]] = None,
trans_config: Optional[Dict[str, str]] = None,
build_config: Optional[Dict[str, str]] = None,
opt_config: Optional[Dict[str, str]] = None,
) -> Tuple[MSCGraph, Dict[str, tvm.nd.array]]:
"""Change IRModule to MSCGraph.

Parameters
----------
mod: IRModule
The IRModule of relay.
params: dict of <string:tvm.ndarray>
The parameters of the IRModule.
trans_config: dict
The config for transform IRModule.
build_config: dict
The config for build MSCGraph.
opt_config: dict
The config for optimize the relay before translate.

Returns
-------
graph: tvm.contrib.msc.core.ir.MSCGraph
The translated graph.
weights: dict of <string:tvm.ndarray>
The weights from the IRModule.
"""

trans_config = msc_utils.copy_dict(trans_config)
build_config = msc_utils.copy_dict(build_config)
opt_config = msc_utils.copy_dict(opt_config)
# TODO(tong.meng): optimize before translate?
opt_level = opt_config.get("opt_level", 0)
if params:
mod["main"] = bind_params_by_name(mod["main"], params)
if opt_level > 0:
target = opt_config.get("target", "llvm")
disabled_pass = opt_config.get("disabled_pass", []) + [
"SimplifyInference",
"CanonicalizeOps",
"FuseOps",
"AlterOpLayout",
]
with tvm.transform.PassContext(opt_level=opt_level, disabled_pass=disabled_pass):
mod, params = tvm.relay.optimize(mod, target=target, params=params)
patterns = get_relay_patterns(mod)
passes = [
tvm.relay.transform.InferType(),
tvm.relay.transform.MergeComposite(patterns),
msc_transform.SetExprName(as_relax=False),
]
mod = tvm.transform.Sequential(passes)(mod)
graph = _ffi_api.BuildFromRelay(mod, "main", msc_utils.dump_dict(build_config))
t_weights = _ffi_api.GetRelayWeights(mod, "main")
return graph, normalize_weights(t_weights, graph)


@tvm.relax.expr_functor.visitor
class BYOCChecker(PyExprVisitor):
"""Checker to check if any non-target ops exist"""
Expand Down
Loading
Loading