Skip to content

Commit

Permalink
create function.py
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiics committed Mar 17, 2020
1 parent bd78fb2 commit 9c5c127
Show file tree
Hide file tree
Showing 27 changed files with 152 additions and 109 deletions.
2 changes: 1 addition & 1 deletion docs/langref/relay_expr.rst
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ Additionally, functions in Relay are higher-order, which means that a function c
function or returned by a function, as function expressions evaluate to closures (see the `Closures`_ subsection),
which are values like tensors and tuples.

See :py:class:`~tvm.relay.expr.Function` for the definition and documentation of function nodes.
See :py:class:`~tvm.relay.function.Function` for the definition and documentation of function nodes.

Syntax
~~~~~~
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/autotvm/graph_tuner/base_graph_tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def __init__(self, graph, input_shapes, records, target_ops,
target_op in the input graph and layout transformation benchmark need to be
executed before initialization.
graph : tvm.relay.Expr.Function
graph : tvm.relay.function.Function
Input graph
input_shapes : dict of str to tuple.
Expand Down Expand Up @@ -143,7 +143,7 @@ def __init__(self, graph, input_shapes, records, target_ops,
if isinstance(graph, tvm.IRModule):
graph = graph["main"]

if isinstance(graph, relay.expr.Function):
if isinstance(graph, relay.function.Function):
node_dict = {}
graph = bind_inputs(graph, input_shapes, dtype)
expr2graph(graph, self._target_ops, node_dict, self._node_list)
Expand Down
3 changes: 2 additions & 1 deletion python/tvm/autotvm/graph_tuner/utils/traverse_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@
import tvm
from tvm import relay, autotvm
from tvm.relay import transform
from tvm.relay.expr import Call, Function, TupleGetItem, Var, Constant, Tuple
from tvm.relay.expr import Call, TupleGetItem, Var, Constant, Tuple
from tvm.relay.function import Function
from tvm.relay.ty import TupleType, TensorType
from tvm.autotvm.task import TaskExtractEnv

Expand Down
6 changes: 3 additions & 3 deletions python/tvm/autotvm/task/relay_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def extract_from_program(mod, params, target, target_host=None, ops=None):
Parameters
----------
mod: tvm.IRModule or relay.expr.Function
mod: tvm.IRModule or relay.function.Function
The module or function to tune
params: dict of str to numpy array
The associated parameters of the program
Expand All @@ -88,7 +88,7 @@ def extract_from_multiple_program(mods, params, target, target_host=None, ops=No
Parameters
----------
mods: List[tvm.IRModule] or List[relay.expr.Function]
mods: List[tvm.IRModule] or List[relay.function.Function]
The list of modules or functions to tune
params: List of dict of str to numpy array
The associated parameters of the programs
Expand Down Expand Up @@ -118,7 +118,7 @@ def extract_from_multiple_program(mods, params, target, target_host=None, ops=No
logger.disabled = True

for mod, param in zip(mods, params):
if isinstance(mod, relay.expr.Function):
if isinstance(mod, relay.function.Function):
mod = tvm.IRModule.from_expr(mod)
assert isinstance(mod, tvm.IRModule), \
"only support relay Module or Function to be tuned"
Expand Down
3 changes: 2 additions & 1 deletion python/tvm/relay/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from . import base
from . import ty
from . import expr
from . import function
from . import type_functor
from . import expr_functor
from . import adt
Expand Down Expand Up @@ -87,7 +88,7 @@
Tuple = expr.Tuple
Var = expr.Var
GlobalVar = expr.GlobalVar
Function = expr.Function
Function = function.Function
Call = expr.Call
Let = expr.Let
If = expr.If
Expand Down
7 changes: 4 additions & 3 deletions python/tvm/relay/_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def __new__(cls, *args, **kwds):
from .base import Span, SourceName
from . import adt
from . import expr
from . import function
from . import ty
from . import op

Expand Down Expand Up @@ -481,7 +482,7 @@ def visitMeta(self, ctx: RelayParser.MetaContext):
def mk_func(
self,
ctx: Union[RelayParser.FuncContext, RelayParser.DefnContext]) \
-> expr.Function:
-> function.Function:
"""Construct a function from either a Func or Defn."""
# Enter var scope early to put params in scope.
self.enter_var_scope()
Expand Down Expand Up @@ -511,10 +512,10 @@ def mk_func(
self.exit_var_scope()

attrs = tvm.ir.make_node("DictAttrs", **attr_list) if attr_list is not None else None
return expr.Function(var_list, body, ret_type, type_params, attrs)
return function.Function(var_list, body, ret_type, type_params, attrs)

@spanify
def visitFunc(self, ctx: RelayParser.FuncContext) -> expr.Function:
def visitFunc(self, ctx: RelayParser.FuncContext) -> function.Function:
return self.mk_func(ctx)

# TODO: how to set spans for definitions?
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/analysis/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,7 +421,7 @@ def extract_fused_functions(mod):
Returns
-------
ret : Dict[int, tvm.relay.ir.expr.Function]
ret : Dict[int, tvm.relay.function.Function]
A module containing only fused primitive functions
"""
ret_mod = _ffi_api.ExtractFusedFunctions()(mod)
Expand Down
3 changes: 2 additions & 1 deletion python/tvm/relay/backend/compile_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from ... import target as _target
from ... import autotvm
from .. import expr as _expr
from .. import function as _function
from .. import op as _op
from .. import ty as _ty
from . import _backend
Expand Down Expand Up @@ -65,7 +66,7 @@ class CCacheValue(Object):


def _get_cache_key(source_func, target):
if isinstance(source_func, _expr.Function):
if isinstance(source_func, _function.Function):
if isinstance(target, str):
target = _target.create(target)
if not target:
Expand Down
3 changes: 2 additions & 1 deletion python/tvm/relay/backend/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@
from . import _backend
from .. import _make, analysis, transform
from ... import nd
from ..expr import Tuple, RefCreate, Call, Constant, GlobalVar, Function, const
from ..expr import Tuple, RefCreate, Call, Constant, GlobalVar, const
from ..function import Function
from ..scope_builder import ScopeBuilder


Expand Down
13 changes: 7 additions & 6 deletions python/tvm/relay/build_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from . import _build_module
from . import ty as _ty
from . import expr as _expr
from . import function as _function
from .backend import interpreter as _interpreter
from .backend.vm import VMExecutor

Expand Down Expand Up @@ -218,16 +219,16 @@ def build(mod, target=None, target_host=None, params=None):
params : dict
The parameters of the final graph.
"""
if not isinstance(mod, (IRModule, _expr.Function)):
if not isinstance(mod, (IRModule, _function.Function)):
raise ValueError("Type of input parameter mod must be tvm.IRModule")

if isinstance(mod, _expr.Function):
if isinstance(mod, _function.Function):
if params:
mod = bind_params_by_name(mod, params)
mod = IRModule.from_expr(mod)
warnings.warn(
"Please use input parameter mod (tvm.IRModule) "
"instead of deprecated parameter mod (tvm.relay.expr.Function)",
"instead of deprecated parameter mod (tvm.relay.function.Function)",
DeprecationWarning)

target = _update_target(target)
Expand Down Expand Up @@ -276,16 +277,16 @@ def optimize(mod, target=None, params=None):
params : dict
The parameters of the final graph.
"""
if not isinstance(mod, (IRModule, _expr.Function)):
if not isinstance(mod, (IRModule, _function.Function)):
raise ValueError("Type of input parameter mod must be tvm.IRModule")

if isinstance(mod, _expr.Function):
if isinstance(mod, _function.Function):
if params:
mod = bind_params_by_name(mod, params)
mod = IRModule.from_expr(mod)
warnings.warn(
"Please use input parameter mod (tvm.IRModule) "
"instead of deprecated parameter func (tvm.relay.expr.Function)",
"instead of deprecated parameter func (tvm.relay.function.Function)",
DeprecationWarning)

target = _update_target(target)
Expand Down
66 changes: 2 additions & 64 deletions python/tvm/relay/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
import numpy as _np
import tvm._ffi
from tvm._ffi import base as _base
from tvm.runtime import NDArray, convert, ndarray as _nd
from tvm.ir import RelayExpr, GlobalVar, BaseFunc
from tvm.runtime import NDArray, ndarray as _nd
from tvm.ir import RelayExpr, GlobalVar

from .base import RelayNode
from . import _ffi_api
Expand Down Expand Up @@ -225,68 +225,6 @@ def name_hint(self):
return name


@tvm._ffi.register_object("relay.Function")
class Function(BaseFunc):
"""A function declaration expression.
Parameters
----------
params: List[tvm.relay.Var]
List of input parameters to the function.
body: tvm.relay.Expr
The body of the function.
ret_type: Optional[tvm.relay.Type]
The return type annotation of the function.
type_params: Optional[List[tvm.relay.TypeParam]]
The additional type parameters, this is only
used in advanced usecase of template functions.
"""
def __init__(self,
params,
body,
ret_type=None,
type_params=None,
attrs=None):
if type_params is None:
type_params = convert([])

self.__init_handle_by_constructor__(
_ffi_api.Function, params, body, ret_type, type_params, attrs)

def __call__(self, *args):
"""Invoke the global function.
Parameters
----------
args: List[relay.Expr]
Arguments.
"""
return Call(self, args, None, None)

def with_attr(self, attr_key, attr_value):
"""Create a new copy of the function and update the attribute
Parameters
----------
attr_key : str
The attribute key to use.
attr_value : Object
The new attribute value.
Returns
-------
func : Function
A new copy of the function
"""
return _ffi_api.FunctionWithAttr(
self, attr_key, convert(attr_value))



@tvm._ffi.register_object("relay.Call")
class Call(ExprWithOp):
"""Function call node in Relay.
Expand Down
3 changes: 2 additions & 1 deletion python/tvm/relay/expr_functor.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name
"""The expression functor of Relay."""

from .expr import Function, Call, Let, Var, GlobalVar
from .function import Function
from .expr import Call, Let, Var, GlobalVar
from .expr import If, Tuple, TupleGetItem, Constant
from .expr import RefCreate, RefRead, RefWrite
from .adt import Constructor, Match, Clause
Expand Down
7 changes: 4 additions & 3 deletions python/tvm/relay/frontend/caffe2.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

from .. import analysis
from .. import expr as _expr
from .. import function as _function
from .. import op as _op
from ... import nd as _nd
from .common import AttrCvt, Renamer
Expand Down Expand Up @@ -451,7 +452,7 @@ def from_caffe2(self, init_net, predict_net):
else:
outputs = out[0]

func = _expr.Function(analysis.free_vars(outputs), outputs)
func = _function.Function(analysis.free_vars(outputs), outputs)
self._mod["main"] = func

return self._mod, self._params
Expand Down Expand Up @@ -517,7 +518,7 @@ def _convert_operator(self,
----------
op_type : str
Operator name, such as Convolution, FullyConnected
inputs : list of tvm.relay.expr.Function
inputs : list of tvm.relay.function.Function
List of input inputs.
args : dict
Dict of operator attributes
Expand All @@ -530,7 +531,7 @@ def _convert_operator(self,
Returns
-------
func : tvm.relay.expr.Function
func : tvm.relay.function.Function
Converted relay function
"""
identity_list = identity_list if identity_list else _identity_list
Expand Down
5 changes: 3 additions & 2 deletions python/tvm/relay/frontend/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from topi.util import get_const_tuple

from .. import expr as _expr
from .. import function as _function
from .. import transform as _transform
from .. import op as _op
from .. import analysis
Expand Down Expand Up @@ -459,7 +460,7 @@ def infer_type(node, mod=None):
new_mod.update(mod)
new_mod = _transform.InferType()(new_mod)
entry = new_mod["main"]
return entry if isinstance(node, _expr.Function) else entry.body
return entry if isinstance(node, _function.Function) else entry.body

def infer_shape(inputs, mod=None):
"""A method to get the output type of an intermediate node in the graph."""
Expand Down Expand Up @@ -491,7 +492,7 @@ def infer_value(input_val, params):
# Check that all free variables have associated parameters.
assert all(var.name_hint in params.keys() for var in analysis.free_vars(
input_val)), "All inputs to infer must be available in params."
func = _expr.Function(analysis.free_vars(input_val), input_val)
func = _function.Function(analysis.free_vars(input_val), input_val)
with tvm.relay.build_config(opt_level=0):
graph, lib, params = tvm.relay.build(func, target="llvm", params=params)
ctx = tvm.cpu(0)
Expand Down
3 changes: 2 additions & 1 deletion python/tvm/relay/frontend/coreml.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

from .. import analysis
from .. import expr as _expr
from .. import function as _function
from .. import op as _op
from ... import nd as _nd
from ..._ffi import base as _base
Expand Down Expand Up @@ -503,6 +504,6 @@ def from_coreml(model, shape=None):
for o in spec.description.output]
# for now return first output
outexpr = outexpr[0]
func = _expr.Function(analysis.free_vars(outexpr), outexpr)
func = _function.Function(analysis.free_vars(outexpr), outexpr)
params = {k:_nd.array(np.array(v, dtype=np.float32)) for k, v in etab.params.items()}
return IRModule.from_expr(func), params
3 changes: 2 additions & 1 deletion python/tvm/relay/frontend/darknet.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

from .. import analysis
from .. import expr as _expr
from .. import function as _function
from .common import get_relay_op, new_var

__all__ = ['from_darknet']
Expand Down Expand Up @@ -821,7 +822,7 @@ def from_darknet(self):

outputs = _as_list(sym) + self._outs
outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs)
sym = _expr.Function(analysis.free_vars(outputs), outputs)
sym = _function.Function(analysis.free_vars(outputs), outputs)
return IRModule.from_expr(sym), self._tvmparams

def from_darknet(net,
Expand Down
3 changes: 2 additions & 1 deletion python/tvm/relay/frontend/keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

from .. import analysis
from .. import expr as _expr
from .. import function as _function
from .. import op as _op
from ... import nd as _nd
from .common import ExprTable, new_var
Expand Down Expand Up @@ -914,6 +915,6 @@ def _convert_input_layer(keras_layer):
outexpr = [etab.get_expr(oc[0].name + ":" + str(oc[1]) + ":" + str(oc[2])) \
for oc in model._output_coordinates]
outexpr = outexpr[0] if len(outexpr) == 1 else _expr.Tuple(outexpr)
func = _expr.Function(analysis.free_vars(outexpr), outexpr)
func = _function.Function(analysis.free_vars(outexpr), outexpr)
params = {k:_nd.array(np.array(v, dtype=np.float32)) for k, v in etab.params.items()}
return IRModule.from_expr(func), params
Loading

0 comments on commit 9c5c127

Please sign in to comment.