Migrate module
tqchen committed Feb 11, 2020
1 parent 7ab08e2 commit 4de60af
Showing 123 changed files with 478 additions and 500 deletions.
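The core of the migration: relay.Module becomes tvm.IRModule, which now lives in tvm.ir and is re-exported at the top level of the tvm package. A minimal before/after sketch of the new entry point (an illustrative snippet, not part of the diff):

    import tvm
    from tvm import relay

    # Build a trivial Relay function.
    x = relay.var("x", shape=(2, 2), dtype="float32")
    func = relay.Function([x], relay.nn.softmax(x))

    # Before this commit:  mod = relay.Module.from_expr(func)
    # After this commit:
    mod = tvm.IRModule.from_expr(func)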
4 changes: 2 additions & 2 deletions apps/benchmark/util.py
@@ -34,7 +34,7 @@ def get_network(name, batch_size, dtype='float32'):
     Returns
     -------
-    net: relay.Module
+    net: tvm.IRModule
         The relay function of network definition
     params: dict
         The random parameters for benchmark
@@ -70,7 +70,7 @@ def get_network(name, batch_size, dtype='float32'):
         net, params = relay.frontend.from_mxnet(block, shape={'data': input_shape}, dtype=dtype)
         net = net["main"]
         net = relay.Function(net.params, relay.nn.softmax(net.body), None, net.type_params, net.attrs)
-        net = relay.Module.from_expr(net)
+        net = tvm.IRModule.from_expr(net)
     else:
         raise ValueError("Unsupported network: " + name)
1 change: 1 addition & 0 deletions python/tvm/__init__.py
@@ -34,6 +34,7 @@
 from .runtime import ndarray as nd

 # tvm.ir
+from .ir import IRModule
 from .ir import transform
 from . import ir
2 changes: 1 addition & 1 deletion python/tvm/autotvm/graph_tuner/base_graph_tuner.py
@@ -141,7 +141,7 @@ def __init__(self, graph, input_shapes, records, target_ops,
         self._logger.propagate = False

         # Generate workload and schedule dictionaries.
-        if isinstance(graph, relay.Module):
+        if isinstance(graph, tvm.IRModule):
             graph = graph["main"]

         if isinstance(graph, relay.expr.Function):
4 changes: 2 additions & 2 deletions python/tvm/autotvm/graph_tuner/utils/traverse_graph.py
@@ -83,7 +83,7 @@ def expr2graph(expr, target_ops, node_dict, node_list):

     def _infer_type(node):
         """A method to infer the type of a relay expression."""
-        mod = relay.Module.from_expr(node)
+        mod = tvm.IRModule.from_expr(node)
         mod = transform.InferType()(mod)
         entry = mod["main"]
         return entry if isinstance(node, relay.Function) else entry.body
@@ -136,7 +136,7 @@ def _traverse_expr(node):
                 free_var = relay.Var("var_%d" % i, input_type)
                 params.append(free_var)
             call = relay.Call(node.op, params, node.attrs)
-            mod = relay.Module.from_expr(relay.Function(params, call))
+            mod = tvm.IRModule.from_expr(relay.Function(params, call))
             relay.backend.compile_engine.get().clear()
             build_thread = threading.Thread(target=relay.build,
                                             args=(mod,
2 changes: 1 addition & 1 deletion python/tvm/autotvm/graph_tuner/utils/utils.py
@@ -136,7 +136,7 @@ def bind_inputs(expr, input_shapes=None, input_dtypes="float32"):
             rebind_dict[var] = updated_input_dict[var.name_hint]
     updated_expr = relay.expr.bind(expr, rebind_dict)

-    mod = relay.Module.from_expr(updated_expr)
+    mod = tvm.IRModule.from_expr(updated_expr)
     mod = transform.InferType()(mod)
     entry = mod["main"]
     return entry if isinstance(updated_expr, relay.Function) else entry.body
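Both graph-tuner helpers touched above rely on the same idiom: wrap an expression into a module, run type inference, and read the result back from "main". A standalone sketch of that pattern (the expression is illustrative):

    import tvm
    from tvm import relay
    from tvm.relay import transform

    x = relay.var("x", shape=(4,), dtype="float32")
    expr = relay.add(x, x)

    mod = tvm.IRModule.from_expr(expr)     # wrap into an implicit "main"
    mod = transform.InferType()(mod)       # run the InferType pass
    entry = mod["main"]
    typed = entry if isinstance(expr, relay.Function) else entry.body
    print(typed.checked_type)              # the inferred TensorType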
8 changes: 4 additions & 4 deletions python/tvm/autotvm/task/relay_integration.py
@@ -63,7 +63,7 @@ def extract_from_program(mod, params, ops, target, target_host=None,
     Parameters
     ----------
-    mod: relay.module.Module or relay.expr.Function
+    mod: tvm.IRModule or relay.expr.Function
         The module or function to tune
     params: dict of str to numpy array
         The associated parameters of the program
@@ -95,7 +95,7 @@ def extract_from_multiple_program(mods, params, ops, target, target_host=None,
     Parameters
     ----------
-    mods: List[relay.module.Module] or List[relay.expr.Function]
+    mods: List[tvm.IRModule] or List[relay.expr.Function]
         The list of modules or functions to tune
     params: List of dict of str to numpy array
         The associated parameters of the programs
@@ -151,8 +151,8 @@ def extract_from_multiple_program(mods, params, ops, target, target_host=None,

     for mod, param in zip(mods, params):
         if isinstance(mod, relay.expr.Function):
-            mod = relay.Module.from_expr(mod)
-        assert isinstance(mod, relay.module.Module), \
+            mod = tvm.IRModule.from_expr(mod)
+        assert isinstance(mod, tvm.IRModule), \
             "only support relay Module or Function to be tuned"
         relay.backend.compile_engine.get().clear()
         # wrap build call in thread to avoid multiprocessing problems
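As the updated docstrings state, the autotvm extraction helpers accept either a tvm.IRModule or a bare relay Function. A hedged usage sketch (shapes and the op list are illustrative, not taken from this commit):

    import tvm
    from tvm import autotvm, relay

    # A tiny conv2d workload to extract tuning tasks from.
    data = relay.var("data", shape=(1, 3, 224, 224), dtype="float32")
    weight = relay.var("weight", shape=(16, 3, 3, 3), dtype="float32")
    out = relay.nn.conv2d(data, weight, kernel_size=(3, 3), channels=16)
    mod = tvm.IRModule.from_expr(relay.Function([data, weight], out))

    # A bare relay.Function would also be accepted; it is wrapped via
    # tvm.IRModule.from_expr internally, as the loop above shows.
    tasks = autotvm.task.extract_from_program(
        mod, params={}, ops=(relay.op.nn.conv2d,), target="llvm")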
2 changes: 1 addition & 1 deletion python/tvm/contrib/sparse.py
@@ -38,7 +38,7 @@ def __init__(self, arg1, ctx=None, shape=None):
             The corresponding a dense numpy array,
             or a tuple for constructing a sparse matrix directly.
-        ctx: tvm.TVMContext
+        ctx: tvmContext
             The corresponding context.
         shape : tuple of int
1 change: 1 addition & 0 deletions python/tvm/ir/__init__.py
@@ -23,4 +23,5 @@
 from .type_relation import TypeCall, TypeRelation
 from .tensor_type import TensorType
 from .adt import Constructor, TypeData
+from .module import IRModule
 from . import transform
92 changes: 41 additions & 51 deletions python/tvm/relay/module.py → python/tvm/ir/module.py
@@ -14,44 +14,34 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable, wildcard-import
-"""A global module storing everything needed to interpret or compile a Relay program."""
-import os
-from .base import register_relay_node, RelayNode
-from .. import register_func
-from .._ffi import base as _base
-from . import _make
-from . import _module
-from . import expr as _expr
-from . import ty as _ty
+"""IRModule that holds the functions and type definitions."""
+from tvm._ffi.base import string_types
+import tvm._ffi

-__STD_PATH__ = os.path.join(os.path.dirname(os.path.realpath(__file__)), "std")
+from .base import Node
+from . import expr as _expr
+from . import type as _ty
+from . import _ffi_api

-@register_func("tvm.relay.std_path")
-def _std_path():
-    global __STD_PATH__
-    return __STD_PATH__

-@register_relay_node
-class Module(RelayNode):
-    """The global Relay module containing collection of functions.
+@tvm._ffi.register_object("relay.Module")
+class IRModule(Node):
+    """IRModule that holds functions and type definitions.

-    Each global function is identified by an unique tvm.relay.GlobalVar.
-    tvm.relay.GlobalVar and Module is necessary in order to enable
-    recursions in function to avoid cyclic reference in the function.x
+    IRModule is the basic unit for all IR transformations across the stack.

     Parameters
     ----------
     functions: Optional[dict].
-        Map of global var to Function
+        Map of global var to BaseFunc
     """
     def __init__(self, functions=None, type_definitions=None):
         if functions is None:
             functions = {}
         elif isinstance(functions, dict):
             mapped_funcs = {}
             for k, v in functions.items():
-                if isinstance(k, _base.string_types):
+                if isinstance(k, string_types):
                     k = _expr.GlobalVar(k)
                 if not isinstance(k, _expr.GlobalVar):
                     raise TypeError("Expect functions to be Dict[GlobalVar, Function]")
@@ -62,13 +52,13 @@ def __init__(self, functions=None, type_definitions=None):
         elif isinstance(type_definitions, dict):
             mapped_type_defs = {}
             for k, v in type_definitions.items():
-                if isinstance(k, _base.string_types):
+                if isinstance(k, string_types):
                     k = _ty.GlobalTypeVar(k)
                 if not isinstance(k, _ty.GlobalTypeVar):
                     raise TypeError("Expect type_definitions to be Dict[GlobalTypeVar, Type]")
                 mapped_type_defs[k] = v
             type_definitions = mapped_type_defs
-        self.__init_handle_by_constructor__(_make.Module, functions, type_definitions)
+        self.__init_handle_by_constructor__(_ffi_api.IRModule, functions, type_definitions)

     def __setitem__(self, var, val):
@@ -86,17 +76,17 @@ def __setitem__(self, var, val):

     def _add(self, var, val, update=False):
         if isinstance(val, _expr.RelayExpr):
-            if isinstance(var, _base.string_types):
-                if _module.Module_ContainGlobalVar(self, var):
-                    var = _module.Module_GetGlobalVar(self, var)
+            if isinstance(var, string_types):
+                if _ffi_api.Module_ContainGlobalVar(self, var):
+                    var = _ffi_api.Module_GetGlobalVar(self, var)
                 else:
                     var = _expr.GlobalVar(var)
-            _module.Module_Add(self, var, val, update)
+            _ffi_api.Module_Add(self, var, val, update)
         else:
             assert isinstance(val, _ty.Type)
-            if isinstance(var, _base.string_types):
+            if isinstance(var, string_types):
                 var = _ty.GlobalTypeVar(var)
-            _module.Module_AddDef(self, var, val, update)
+            _ffi_api.Module_AddDef(self, var, val, update)

     def __getitem__(self, var):
         """Lookup a global definition by name or by variable.
@@ -111,12 +101,11 @@ def __getitem__(self, var):
         val: Union[Function, Type]
             The definition referenced by :code:`var` (either a function or type).
         """
-        if isinstance(var, _base.string_types):
-            return _module.Module_Lookup_str(self, var)
-        elif isinstance(var, _expr.GlobalVar):
-            return _module.Module_Lookup(self, var)
-        else:
-            return _module.Module_LookupDef(self, var)
+        if isinstance(var, string_types):
+            return _ffi_api.Module_Lookup_str(self, var)
+        if isinstance(var, _expr.GlobalVar):
+            return _ffi_api.Module_Lookup(self, var)
+        return _ffi_api.Module_LookupDef(self, var)

     def update(self, other):
         """Insert functions in another Module to current one.
@@ -128,7 +117,7 @@ def update(self, other):
         """
         if isinstance(other, dict):
             other = Module(other)
-        return _module.Module_Update(self, other)
+        return _ffi_api.Module_Update(self, other)

     def get_global_var(self, name):
         """Get a global variable in the function by name.
@@ -145,9 +134,9 @@ def get_global_var(self, name):
         Raises
         ------
-        tvm.TVMError if we cannot find corresponding global var.
+        tvm.error.TVMError if we cannot find corresponding global var.
         """
-        return _module.Module_GetGlobalVar(self, name)
+        return _ffi_api.Module_GetGlobalVar(self, name)

     def get_global_vars(self):
         """Collect all global vars defined in this module.
@@ -157,7 +146,7 @@ def get_global_vars(self):
         global_vars: tvm.Array[GlobalVar]
             An array of global vars.
         """
-        return _module.Module_GetGlobalVars(self)
+        return _ffi_api.Module_GetGlobalVars(self)

     def get_global_type_vars(self):
         """Collect all global type vars defined in this module.
@@ -167,7 +156,7 @@ def get_global_type_vars(self):
         global_type_vars: tvm.Array[GlobalTypeVar]
             An array of global type vars.
         """
-        return _module.Module_GetGlobalTypeVars(self)
+        return _ffi_api.Module_GetGlobalTypeVars(self)

     def get_global_type_var(self, name):
         """Get a global type variable in the function by name.
@@ -184,9 +173,9 @@ def get_global_type_var(self, name):
         Raises
         ------
-        tvm.TVMError if we cannot find corresponding global type var.
+        tvm.error.TVMError if we cannot find corresponding global type var.
         """
-        return _module.Module_GetGlobalTypeVar(self, name)
+        return _ffi_api.Module_GetGlobalTypeVar(self, name)

     def get_constructor(self, tag):
         """Look up an ADT constructor by tag.
@@ -203,24 +192,25 @@ def get_constructor(self, tag):
         Raises
         ------
-        tvm.TVMError if the corresponding constructor cannot be found.
+        tvm.error.TVMError if the corresponding constructor cannot be found.
         """
-        return _module.Module_LookupTag(self, tag)
+        return _ffi_api.Module_LookupTag(self, tag)

     @staticmethod
     def from_expr(expr, functions=None, type_defs=None):
         """Construct a module from a standalone expression.

         Parameters
         ----------
-        expr: Expr
+        expr: RelayExpr
             The starting expression
         global_funcs: Optional[dict]
             Map of global vars to function definitions
         type_defs: Optional[dict]
             Map of global type vars to type definitions

         Returns
         -------
         mod: Module
@@ -230,10 +220,10 @@ def from_expr(expr, functions=None, type_defs=None):
         """
         funcs = functions if functions is not None else {}
         defs = type_defs if type_defs is not None else {}
-        return _module.Module_FromExpr(expr, funcs, defs)
+        return _ffi_api.Module_FromExpr(expr, funcs, defs)

     def _import(self, file_to_import):
-        return _module.Module_Import(self, file_to_import)
+        return _ffi_api.Module_Import(self, file_to_import)

     def import_from_std(self, file_to_import):
-        return _module.Module_ImportFromStd(self, file_to_import)
+        return _ffi_api.Module_ImportFromStd(self, file_to_import)
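Apart from the FFI routing (_module/_make → _ffi_api) and the stricter raise documentation, the public surface of the renamed class behaves as the docstrings above describe. A short sketch of the API after the rename (function bodies are illustrative):

    import tvm
    from tvm import relay

    # from_expr wraps a standalone expression into an implicit "main".
    x = relay.var("x", shape=(10,), dtype="float32")
    mod = tvm.IRModule.from_expr(relay.Function([x], relay.abs(x)))

    # __setitem__/__getitem__ accept a string name or a GlobalVar.
    y = relay.var("y", shape=(10,), dtype="float32")
    mod["double"] = relay.Function([y], relay.add(y, y))
    gv = mod.get_global_var("double")  # lookup by name
    print(mod[gv])                     # same function, fetched by GlobalVar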
8 changes: 4 additions & 4 deletions python/tvm/ir/transform.py
@@ -130,20 +130,20 @@ def __call__(self, mod):

         Parameters
         ----------
-        mod : tvm.relay.Module
+        mod : tvm.IRModule
             The module that a certain optimization is performed on.

         Returns
         -------
-        mod : tvm.relay.Module
+        mod : tvm.IRModule
             The updated module after applying this pass.
         """
         return _ffi_transform_api.RunPass(self, mod)


 @tvm._ffi.register_object("relay.ModulePass")
 class ModulePass(Pass):
-    """A pass that works on tvm.relay.Module. Users don't need to interact with
+    """A pass that works on tvm.IRModule. Users don't need to interact with
     this class directly. Instead, a module pass should be created through
     `module_pass`, because the design of the `module_pass` API is flexible
     enough to handle the creation of a module pass in different manners. In
@@ -293,7 +293,7 @@ def transform(mod, ctx):
             x = relay.var("x", tp)
             gv = relay.GlobalVar("var")
             func = relay.Function([x], relay.abs(x))
-            new_mod = relay.Module({gv: func})
+            new_mod = tvm.IRModule({gv: func})
             new_mod.update(mod)
             return new_mod
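Since module passes now run on tvm.IRModule, a user-defined pass in the style of the docstring example above looks like this (a hedged sketch; the pass body is illustrative, assuming the tvm.transform alias exported in python/tvm/__init__.py above):

    import tvm
    from tvm import relay

    # Inject one extra function, then merge in the original module.
    @tvm.transform.module_pass(opt_level=1)
    def inject_abs(mod, ctx):
        x = relay.var("x", shape=(2,), dtype="float32")
        gv = relay.GlobalVar("my_abs")
        new_mod = tvm.IRModule({gv: relay.Function([x], relay.abs(x))})
        new_mod.update(mod)
        return new_mod

    mod = inject_abs(tvm.IRModule())  # __call__ dispatches to RunPass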
5 changes: 0 additions & 5 deletions python/tvm/relay/__init__.py
@@ -16,7 +16,6 @@
 # under the License.
 # pylint: disable=wildcard-import, redefined-builtin, invalid-name
 """The Relay IR namespace containing the IR definition and compiler."""
-from __future__ import absolute_import
 import os
 from sys import setrecursionlimit
 from ..api import register_func
@@ -25,7 +24,6 @@
 from . import expr
 from . import type_functor
 from . import expr_functor
-from . import module
 from . import adt
 from . import analysis
 from . import transform
@@ -67,9 +65,6 @@
 # Span
 Span = base.Span

-# Env
-Module = module.Module
-
 # Type
 Type = ty.Type
 TupleType = ty.TupleType
21 changes: 0 additions & 21 deletions python/tvm/relay/_module.py

This file was deleted.
