diff --git a/docs/langref/relay_expr.rst b/docs/langref/relay_expr.rst index 66bfe43a04d63..3b93360453eb3 100644 --- a/docs/langref/relay_expr.rst +++ b/docs/langref/relay_expr.rst @@ -120,7 +120,7 @@ Additionally, functions in Relay are higher-order, which means that a function c function or returned by a function, as function expressions evaluate to closures (see the `Closures`_ subsection), which are values like tensors and tuples. -See :py:class:`~tvm.relay.expr.Function` for the definition and documentation of function nodes. +See :py:class:`~tvm.relay.function.Function` for the definition and documentation of function nodes. Syntax ~~~~~~ diff --git a/python/tvm/autotvm/graph_tuner/base_graph_tuner.py b/python/tvm/autotvm/graph_tuner/base_graph_tuner.py index f1a075637ae94..e7b4694cc53d6 100644 --- a/python/tvm/autotvm/graph_tuner/base_graph_tuner.py +++ b/python/tvm/autotvm/graph_tuner/base_graph_tuner.py @@ -69,7 +69,7 @@ def __init__(self, graph, input_shapes, records, target_ops, target_op in the input graph and layout transformation benchmark need to be executed before initialization. - graph : tvm.relay.Expr.Function + graph : tvm.relay.function.Function Input graph input_shapes : dict of str to tuple. 
@@ -143,7 +143,7 @@ def __init__(self, graph, input_shapes, records, target_ops, if isinstance(graph, tvm.IRModule): graph = graph["main"] - if isinstance(graph, relay.expr.Function): + if isinstance(graph, relay.function.Function): node_dict = {} graph = bind_inputs(graph, input_shapes, dtype) expr2graph(graph, self._target_ops, node_dict, self._node_list) diff --git a/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py b/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py index f1dd40440532b..8470fb6815996 100644 --- a/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py +++ b/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py @@ -21,7 +21,8 @@ import tvm from tvm import relay, autotvm from tvm.relay import transform -from tvm.relay.expr import Call, Function, TupleGetItem, Var, Constant, Tuple +from tvm.relay.expr import Call, TupleGetItem, Var, Constant, Tuple +from tvm.relay.function import Function from tvm.relay.ty import TupleType, TensorType from tvm.autotvm.task import TaskExtractEnv diff --git a/python/tvm/autotvm/task/relay_integration.py b/python/tvm/autotvm/task/relay_integration.py index cd8d32fb2d68b..a7cbef74e5c27 100644 --- a/python/tvm/autotvm/task/relay_integration.py +++ b/python/tvm/autotvm/task/relay_integration.py @@ -61,7 +61,7 @@ def extract_from_program(mod, params, target, target_host=None, ops=None): Parameters ---------- - mod: tvm.IRModule or relay.expr.Function + mod: tvm.IRModule or relay.function.Function The module or function to tune params: dict of str to numpy array The associated parameters of the program @@ -88,7 +88,7 @@ def extract_from_multiple_program(mods, params, target, target_host=None, ops=No Parameters ---------- - mods: List[tvm.IRModule] or List[relay.expr.Function] + mods: List[tvm.IRModule] or List[relay.function.Function] The list of modules or functions to tune params: List of dict of str to numpy array The associated parameters of the programs @@ -118,7 +118,7 @@ def 
extract_from_multiple_program(mods, params, target, target_host=None, ops=No logger.disabled = True for mod, param in zip(mods, params): - if isinstance(mod, relay.expr.Function): + if isinstance(mod, relay.function.Function): mod = tvm.IRModule.from_expr(mod) assert isinstance(mod, tvm.IRModule), \ "only support relay Module or Function to be tuned" diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py index b1aac3e606a2b..95545c8cd5596 100644 --- a/python/tvm/relay/__init__.py +++ b/python/tvm/relay/__init__.py @@ -22,6 +22,7 @@ from . import base from . import ty from . import expr +from . import function from . import type_functor from . import expr_functor from . import adt @@ -87,7 +88,7 @@ Tuple = expr.Tuple Var = expr.Var GlobalVar = expr.GlobalVar -Function = expr.Function +Function = function.Function Call = expr.Call Let = expr.Let If = expr.If diff --git a/python/tvm/relay/_parser.py b/python/tvm/relay/_parser.py index 49bdbb393c2e5..4a73e572f9240 100644 --- a/python/tvm/relay/_parser.py +++ b/python/tvm/relay/_parser.py @@ -43,6 +43,7 @@ def __new__(cls, *args, **kwds): from .base import Span, SourceName from . import adt from . import expr +from . import function from . import ty from . import op @@ -481,7 +482,7 @@ def visitMeta(self, ctx: RelayParser.MetaContext): def mk_func( self, ctx: Union[RelayParser.FuncContext, RelayParser.DefnContext]) \ - -> expr.Function: + -> function.Function: """Construct a function from either a Func or Defn.""" # Enter var scope early to put params in scope. 
self.enter_var_scope() @@ -511,10 +512,10 @@ def mk_func( self.exit_var_scope() attrs = tvm.ir.make_node("DictAttrs", **attr_list) if attr_list is not None else None - return expr.Function(var_list, body, ret_type, type_params, attrs) + return function.Function(var_list, body, ret_type, type_params, attrs) @spanify - def visitFunc(self, ctx: RelayParser.FuncContext) -> expr.Function: + def visitFunc(self, ctx: RelayParser.FuncContext) -> function.Function: return self.mk_func(ctx) # TODO: how to set spans for definitions? diff --git a/python/tvm/relay/analysis/analysis.py b/python/tvm/relay/analysis/analysis.py index beb3c6599e287..722f3b0630dee 100644 --- a/python/tvm/relay/analysis/analysis.py +++ b/python/tvm/relay/analysis/analysis.py @@ -421,7 +421,7 @@ def extract_fused_functions(mod): Returns ------- - ret : Dict[int, tvm.relay.ir.expr.Function] + ret : Dict[int, tvm.relay.function.Function] A module containing only fused primitive functions """ ret_mod = _ffi_api.ExtractFusedFunctions()(mod) diff --git a/python/tvm/relay/backend/compile_engine.py b/python/tvm/relay/backend/compile_engine.py index 03d91d5beb0ff..f023335feca17 100644 --- a/python/tvm/relay/backend/compile_engine.py +++ b/python/tvm/relay/backend/compile_engine.py @@ -26,6 +26,7 @@ from ... import target as _target from ... import autotvm from .. import expr as _expr +from .. import function as _function from .. import op as _op from .. import ty as _ty from . import _backend @@ -65,7 +66,7 @@ class CCacheValue(Object): def _get_cache_key(source_func, target): - if isinstance(source_func, _expr.Function): + if isinstance(source_func, _function.Function): if isinstance(target, str): target = _target.create(target) if not target: diff --git a/python/tvm/relay/backend/interpreter.py b/python/tvm/relay/backend/interpreter.py index ab39f7c564468..9c4be2975d6cd 100644 --- a/python/tvm/relay/backend/interpreter.py +++ b/python/tvm/relay/backend/interpreter.py @@ -27,7 +27,8 @@ from . 
import _backend from .. import _make, analysis, transform from ... import nd -from ..expr import Tuple, RefCreate, Call, Constant, GlobalVar, Function, const +from ..expr import Tuple, RefCreate, Call, Constant, GlobalVar, const +from ..function import Function from ..scope_builder import ScopeBuilder diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py index d1add2726d782..30c5971e32b91 100644 --- a/python/tvm/relay/build_module.py +++ b/python/tvm/relay/build_module.py @@ -29,6 +29,7 @@ from . import _build_module from . import ty as _ty from . import expr as _expr +from . import function as _function from .backend import interpreter as _interpreter from .backend.vm import VMExecutor @@ -218,16 +219,16 @@ def build(mod, target=None, target_host=None, params=None): params : dict The parameters of the final graph. """ - if not isinstance(mod, (IRModule, _expr.Function)): + if not isinstance(mod, (IRModule, _function.Function)): raise ValueError("Type of input parameter mod must be tvm.IRModule") - if isinstance(mod, _expr.Function): + if isinstance(mod, _function.Function): if params: mod = bind_params_by_name(mod, params) mod = IRModule.from_expr(mod) warnings.warn( "Please use input parameter mod (tvm.IRModule) " - "instead of deprecated parameter mod (tvm.relay.expr.Function)", + "instead of deprecated parameter mod (tvm.relay.function.Function)", DeprecationWarning) target = _update_target(target) @@ -276,16 +277,16 @@ def optimize(mod, target=None, params=None): params : dict The parameters of the final graph. 
""" - if not isinstance(mod, (IRModule, _expr.Function)): + if not isinstance(mod, (IRModule, _function.Function)): raise ValueError("Type of input parameter mod must be tvm.IRModule") - if isinstance(mod, _expr.Function): + if isinstance(mod, _function.Function): if params: mod = bind_params_by_name(mod, params) mod = IRModule.from_expr(mod) warnings.warn( "Please use input parameter mod (tvm.IRModule) " - "instead of deprecated parameter func (tvm.relay.expr.Function)", + "instead of deprecated parameter func (tvm.relay.function.Function)", DeprecationWarning) target = _update_target(target) diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py index 380cdf7d90efa..ff13683949173 100644 --- a/python/tvm/relay/expr.py +++ b/python/tvm/relay/expr.py @@ -22,8 +22,8 @@ import numpy as _np import tvm._ffi from tvm._ffi import base as _base -from tvm.runtime import NDArray, convert, ndarray as _nd -from tvm.ir import RelayExpr, GlobalVar, BaseFunc +from tvm.runtime import NDArray, ndarray as _nd +from tvm.ir import RelayExpr, GlobalVar from .base import RelayNode from . import _ffi_api @@ -225,68 +225,6 @@ def name_hint(self): return name -@tvm._ffi.register_object("relay.Function") -class Function(BaseFunc): - """A function declaration expression. - - Parameters - ---------- - params: List[tvm.relay.Var] - List of input parameters to the function. - - body: tvm.relay.Expr - The body of the function. - - ret_type: Optional[tvm.relay.Type] - The return type annotation of the function. - - type_params: Optional[List[tvm.relay.TypeParam]] - The additional type parameters, this is only - used in advanced usecase of template functions. - """ - def __init__(self, - params, - body, - ret_type=None, - type_params=None, - attrs=None): - if type_params is None: - type_params = convert([]) - - self.__init_handle_by_constructor__( - _ffi_api.Function, params, body, ret_type, type_params, attrs) - - def __call__(self, *args): - """Invoke the global function. 
- - Parameters - ---------- - args: List[relay.Expr] - Arguments. - """ - return Call(self, args, None, None) - - def with_attr(self, attr_key, attr_value): - """Create a new copy of the function and update the attribute - - Parameters - ---------- - attr_key : str - The attribute key to use. - - attr_value : Object - The new attribute value. - - Returns - ------- - func : Function - A new copy of the function - """ - return _ffi_api.FunctionWithAttr( - self, attr_key, convert(attr_value)) - - - @tvm._ffi.register_object("relay.Call") class Call(ExprWithOp): """Function call node in Relay. diff --git a/python/tvm/relay/expr_functor.py b/python/tvm/relay/expr_functor.py index 8d6923920979e..874a3a75d5bf5 100644 --- a/python/tvm/relay/expr_functor.py +++ b/python/tvm/relay/expr_functor.py @@ -17,7 +17,8 @@ # pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name """The expression functor of Relay.""" -from .expr import Function, Call, Let, Var, GlobalVar +from .function import Function +from .expr import Call, Let, Var, GlobalVar from .expr import If, Tuple, TupleGetItem, Constant from .expr import RefCreate, RefRead, RefWrite from .adt import Constructor, Match, Clause diff --git a/python/tvm/relay/frontend/caffe2.py b/python/tvm/relay/frontend/caffe2.py index da0cc6479818e..f4fcd9237ed34 100644 --- a/python/tvm/relay/frontend/caffe2.py +++ b/python/tvm/relay/frontend/caffe2.py @@ -21,6 +21,7 @@ from .. import analysis from .. import expr as _expr +from .. import function as _function from .. import op as _op from ... 
import nd as _nd from .common import AttrCvt, Renamer @@ -451,7 +452,7 @@ def from_caffe2(self, init_net, predict_net): else: outputs = out[0] - func = _expr.Function(analysis.free_vars(outputs), outputs) + func = _function.Function(analysis.free_vars(outputs), outputs) self._mod["main"] = func return self._mod, self._params @@ -517,7 +518,7 @@ def _convert_operator(self, ---------- op_type : str Operator name, such as Convolution, FullyConnected - inputs : list of tvm.relay.expr.Function + inputs : list of tvm.relay.function.Function List of input inputs. args : dict Dict of operator attributes @@ -530,7 +531,7 @@ def _convert_operator(self, Returns ------- - func : tvm.relay.expr.Function + func : tvm.relay.function.Function Converted relay function """ identity_list = identity_list if identity_list else _identity_list diff --git a/python/tvm/relay/frontend/common.py b/python/tvm/relay/frontend/common.py index d427fe953085d..6185121228c37 100644 --- a/python/tvm/relay/frontend/common.py +++ b/python/tvm/relay/frontend/common.py @@ -24,6 +24,7 @@ from topi.util import get_const_tuple from .. import expr as _expr +from .. import function as _function from .. import transform as _transform from .. import op as _op from .. import analysis @@ -459,7 +460,7 @@ def infer_type(node, mod=None): new_mod.update(mod) new_mod = _transform.InferType()(new_mod) entry = new_mod["main"] - return entry if isinstance(node, _expr.Function) else entry.body + return entry if isinstance(node, _function.Function) else entry.body def infer_shape(inputs, mod=None): """A method to get the output type of an intermediate node in the graph.""" @@ -491,7 +492,7 @@ def infer_value(input_val, params): # Check that all free variables have associated parameters. assert all(var.name_hint in params.keys() for var in analysis.free_vars( input_val)), "All inputs to infer must be available in params." 
- func = _expr.Function(analysis.free_vars(input_val), input_val) + func = _function.Function(analysis.free_vars(input_val), input_val) with tvm.relay.build_config(opt_level=0): graph, lib, params = tvm.relay.build(func, target="llvm", params=params) ctx = tvm.cpu(0) diff --git a/python/tvm/relay/frontend/coreml.py b/python/tvm/relay/frontend/coreml.py index 0e5b64cbbacc4..6658803b3ade2 100644 --- a/python/tvm/relay/frontend/coreml.py +++ b/python/tvm/relay/frontend/coreml.py @@ -24,6 +24,7 @@ from .. import analysis from .. import expr as _expr +from .. import function as _function from .. import op as _op from ... import nd as _nd from ..._ffi import base as _base @@ -503,6 +504,6 @@ def from_coreml(model, shape=None): for o in spec.description.output] # for now return first output outexpr = outexpr[0] - func = _expr.Function(analysis.free_vars(outexpr), outexpr) + func = _function.Function(analysis.free_vars(outexpr), outexpr) params = {k:_nd.array(np.array(v, dtype=np.float32)) for k, v in etab.params.items()} return IRModule.from_expr(func), params diff --git a/python/tvm/relay/frontend/darknet.py b/python/tvm/relay/frontend/darknet.py index 0dae645cd9a4a..936d7c0dc87fd 100644 --- a/python/tvm/relay/frontend/darknet.py +++ b/python/tvm/relay/frontend/darknet.py @@ -26,6 +26,7 @@ from .. import analysis from .. import expr as _expr +from .. 
import function as _function from .common import get_relay_op, new_var __all__ = ['from_darknet'] @@ -821,7 +822,7 @@ def from_darknet(self): outputs = _as_list(sym) + self._outs outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs) - sym = _expr.Function(analysis.free_vars(outputs), outputs) + sym = _function.Function(analysis.free_vars(outputs), outputs) return IRModule.from_expr(sym), self._tvmparams def from_darknet(net, diff --git a/python/tvm/relay/frontend/keras.py b/python/tvm/relay/frontend/keras.py index adb28c4664542..090bd4c5b646d 100644 --- a/python/tvm/relay/frontend/keras.py +++ b/python/tvm/relay/frontend/keras.py @@ -23,6 +23,7 @@ from .. import analysis from .. import expr as _expr +from .. import function as _function from .. import op as _op from ... import nd as _nd from .common import ExprTable, new_var @@ -914,6 +915,6 @@ def _convert_input_layer(keras_layer): outexpr = [etab.get_expr(oc[0].name + ":" + str(oc[1]) + ":" + str(oc[2])) \ for oc in model._output_coordinates] outexpr = outexpr[0] if len(outexpr) == 1 else _expr.Tuple(outexpr) - func = _expr.Function(analysis.free_vars(outexpr), outexpr) + func = _function.Function(analysis.free_vars(outexpr), outexpr) params = {k:_nd.array(np.array(v, dtype=np.float32)) for k, v in etab.params.items()} return IRModule.from_expr(func), params diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py index ba93bb2e3b81c..17be3686f3181 100644 --- a/python/tvm/relay/frontend/mxnet.py +++ b/python/tvm/relay/frontend/mxnet.py @@ -25,6 +25,7 @@ from topi.util import get_const_tuple from .. import analysis from .. import expr as _expr +from .. import function as _function from .. import op as _op from .. import scope_builder as _scope_builder from ... 
import nd as _nd @@ -1096,7 +1097,7 @@ def _mx_cond(inputs, attrs, subgraphs): else_arg_dtype_info = [arg.type_annotation.dtype for arg in else_args] else_func = _from_mxnet_impl(subgraphs[2], else_arg_shapes, else_arg_dtype_info) sb.ret(_expr.Call(else_func, else_args)) - func = _expr.Function(input_args, sb.get()) + func = _function.Function(input_args, sb.get()) ret = _expr.Call(func, inputs) if num_outputs > 1: ret = _expr.TupleWrapper(ret, num_outputs) @@ -1969,7 +1970,7 @@ def _from_mxnet_impl(symbol, shape_dict, dtype_info, params=None, mod=None): outputs = [node_map[e[0]][e[1]] for e in jgraph["heads"]] outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs) - func = _expr.Function(analysis.free_vars(outputs), outputs) + func = _function.Function(analysis.free_vars(outputs), outputs) return func diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 7f417d39e9f81..e1b0a7f01749e 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -24,6 +24,7 @@ from ... import nd as _nd from .. import analysis from .. import expr as _expr +from .. import function as _function from .. import op as _op from .common import AttrCvt, Renamer from .common import get_relay_op, new_var, infer_shape, infer_channels @@ -1708,7 +1709,7 @@ def from_onnx(self, graph, opset): # now return the outputs outputs = [self._nodes[self._parse_value_proto(i)] for i in graph.output] outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs) - func = _expr.Function(analysis.free_vars(outputs), outputs) + func = _function.Function(analysis.free_vars(outputs), outputs) return IRModule.from_expr(func), self._params def _parse_value_proto(self, value_proto): @@ -1774,7 +1775,7 @@ def _convert_operator(self, ---------- op_name : str Operator name, such as Convolution, FullyConnected - inputs : list of tvm.relay.expr.Function + inputs : list of tvm.relay.function.Function List of inputs. 
attrs : dict Dict of operator attributes @@ -1783,7 +1784,7 @@ def _convert_operator(self, Returns ------- - sym : tvm.relay.expr.Function + sym : tvm.relay.function.Function Converted relay function """ convert_map = _get_convert_map(opset) diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 3dca3652f2bed..e0da863a98d09 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -31,6 +31,7 @@ from .. import analysis from .. import expr as _expr +from .. import function as _function from .. import op as _op from ..expr_functor import ExprMutator from .common import AttrCvt, get_relay_op @@ -2461,7 +2462,7 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None): out.append(out_rnn) out = out[0] if len(out) == 1 else _expr.Tuple(out) - func = _expr.Function(analysis.free_vars(out), out) + func = _function.Function(analysis.free_vars(out), out) self._mod["main"] = func return self._mod, self._params diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 95f757989925e..aa5157024eeb2 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -24,6 +24,7 @@ from tvm import relay from .. import analysis from .. import expr as _expr +from .. import function as _function from .. import op as _op from .. import qnn as _qnn from ... 
import nd as _nd @@ -2365,6 +2366,6 @@ def from_tflite(model, shape_dict, dtype_dict): params = {k:_nd.array(np.array(v)) for k, v in exp_tab.params.items()} outputs = [exp_tab.get_expr(get_tensor_name(subgraph, i)) for i in model_outputs] outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs) - func = _expr.Function(analysis.free_vars(outputs), outputs) + func = _function.Function(analysis.free_vars(outputs), outputs) mod = IRModule.from_expr(func) return mod, params diff --git a/python/tvm/relay/function.py b/python/tvm/relay/function.py new file mode 100644 index 0000000000000..786a7f4cfc249 --- /dev/null +++ b/python/tvm/relay/function.py @@ -0,0 +1,86 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=no-else-return, invalid-name, unused-import
"""The function nodes of Relay.""" +from __future__ import absolute_import +
import tvm._ffi +from tvm.runtime import convert +from tvm.ir import BaseFunc +
from .expr import Call +from . import _ffi_api +
@tvm._ffi.register_object("relay.Function") +class Function(BaseFunc): + """A function declaration expression. + + Parameters + ---------- + params: List[tvm.relay.Var] + List of input parameters to the function.
+ + body: tvm.relay.Expr + The body of the function. + + ret_type: Optional[tvm.relay.Type] + The return type annotation of the function. + + type_params: Optional[List[tvm.relay.TypeParam]] + The additional type parameters, this is only + used in advanced usecase of template functions. + """ + def __init__(self, + params, + body, + ret_type=None, + type_params=None, + attrs=None): + if type_params is None: + type_params = convert([]) + + self.__init_handle_by_constructor__( + _ffi_api.Function, params, body, ret_type, type_params, attrs) + + def __call__(self, *args): + """Invoke the global function. + + Parameters + ---------- + args: List[relay.Expr] + Arguments. + """ + return Call(self, args, None, None) + + def with_attr(self, attr_key, attr_value): + """Create a new copy of the function and update the attribute + + Parameters + ---------- + attr_key : str + The attribute key to use. + + attr_value : Object + The new attribute value. + + Returns + ------- + func : Function + A new copy of the function + """ + return _ffi_api.FunctionWithAttr( + self, attr_key, convert(attr_value)) diff --git a/python/tvm/relay/loops.py b/python/tvm/relay/loops.py index 8e066abda031a..9af6811ea2f63 100644 --- a/python/tvm/relay/loops.py +++ b/python/tvm/relay/loops.py @@ -20,6 +20,7 @@ """ from .scope_builder import ScopeBuilder from . import expr as _expr +from . 
import function as _function def while_loop(cond, loop_vars, loop_bodies): """ @@ -60,6 +61,6 @@ def while_loop(cond, loop_vars, loop_bodies): with sb.else_scope(): sb.ret(_expr.Tuple(fresh_vars)) - func = _expr.Function(fresh_vars, sb.get()) + func = _function.Function(fresh_vars, sb.get()) let = _expr.Let(loop, func, loop) return let diff --git a/python/tvm/relay/prelude.py b/python/tvm/relay/prelude.py index 5288a2e080113..0e64a2fd12485 100644 --- a/python/tvm/relay/prelude.py +++ b/python/tvm/relay/prelude.py @@ -19,7 +19,8 @@ from tvm.ir import IRModule from .ty import GlobalTypeVar, TensorType, Any, scalar_type -from .expr import Var, Function, GlobalVar, If, const +from .expr import Var, GlobalVar, If, const +from .function import Function from .op.tensor import add, subtract, equal from .adt import Constructor, TypeData, Clause, Match from .adt import PatternConstructor, PatternVar, PatternWildcard diff --git a/python/tvm/relay/testing/nat.py b/python/tvm/relay/testing/nat.py index eb71120610d3f..4906eefe5bfdb 100644 --- a/python/tvm/relay/testing/nat.py +++ b/python/tvm/relay/testing/nat.py @@ -21,7 +21,8 @@ from tvm.relay.adt import Constructor, TypeData, Clause, Match, PatternConstructor, PatternVar from tvm.relay.backend.interpreter import ConstructorValue -from tvm.relay.expr import Var, Function, GlobalVar +from tvm.relay.expr import Var, GlobalVar +from tvm.relay.function import Function from tvm.relay.ty import GlobalTypeVar, TypeVar, FuncType def define_nat_adt(prelude): diff --git a/python/tvm/relay/testing/py_converter.py b/python/tvm/relay/testing/py_converter.py index eacfe379137fb..e850000522315 100644 --- a/python/tvm/relay/testing/py_converter.py +++ b/python/tvm/relay/testing/py_converter.py @@ -23,7 +23,8 @@ from tvm import relay from tvm.relay.adt import Pattern from tvm.relay.backend import compile_engine -from tvm.relay.expr import Expr, Function, GlobalVar, Var +from tvm.relay.expr import Expr, GlobalVar, Var +from tvm.relay.function 
import Function from tvm.relay.expr_functor import ExprFunctor OUTPUT_VAR_NAME = '_py_out' diff --git a/src/relay/ir/function.cc b/src/relay/ir/function.cc index d371edb31fca1..63ce0dda656c3 100644 --- a/src/relay/ir/function.cc +++ b/src/relay/ir/function.cc @@ -27,10 +27,10 @@ namespace tvm { namespace relay { Function::Function(tvm::Array<Var> params, - Expr body, - Type ret_type, - tvm::Array<TypeVar> type_params, - DictAttrs attrs) { + Expr body, + Type ret_type, + tvm::Array<TypeVar> type_params, + DictAttrs attrs) { ObjectPtr<FunctionNode> n = make_object<FunctionNode>(); CHECK(params.defined()); CHECK(type_params.defined()); @@ -71,7 +71,6 @@ TVM_REGISTER_GLOBAL("relay.ir.Function") return Function(params, body, ret_type, ty_params, attrs); }); - TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch<FunctionNode>([](const ObjectRef& ref, ReprPrinter* p) { auto* node = static_cast<const FunctionNode*>(ref.get());