diff --git a/python/tvm/__init__.py b/python/tvm/__init__.py index 425ae723c509e..69c24008c10fb 100644 --- a/python/tvm/__init__.py +++ b/python/tvm/__init__.py @@ -36,6 +36,7 @@ # tvm.ir from .ir import IRModule from .ir import transform +from .ir import container from . import ir # others @@ -46,7 +47,6 @@ from . import make from . import ir_pass from . import codegen -from . import container from . import schedule from . import ir_builder diff --git a/python/tvm/api.py b/python/tvm/api.py index eb38654330b83..e7778d6cc5df7 100644 --- a/python/tvm/api.py +++ b/python/tvm/api.py @@ -19,8 +19,11 @@ from numbers import Integral as _Integral import tvm._ffi +import tvm.ir from tvm.runtime import convert, const, DataType +from tvm.ir import container as _container + from ._ffi.base import string_types, TVMError from ._ffi.registry import register_func, get_global_func, extract_ext_funcs @@ -29,7 +32,6 @@ from . import expr as _expr from . import tensor as _tensor from . import schedule as _schedule -from . import container as _container from . import tag as _tag int8 = "int8" @@ -69,29 +71,6 @@ def max_value(dtype): """ return _api_internal._max_value(dtype) - -def get_env_func(name): - """Get an EnvFunc by a global name. - - Parameters - ---------- - name: str - The name of the global function. - - Returns - ------- - env_func : EnvFunc - The result env function. - - Note - ---- - EnvFunc is a Object wrapper around - global function that can be serialized via its name. - This can be used to serialize function field in the language. - """ - return _api_internal._EnvFuncGet(name) - - def var(name="tindex", dtype=int32): """Create a new variable with specified name and dtype @@ -649,7 +628,7 @@ def _IterVar(dom, name, iter_type, thread_tag=''): raise TypeError("need to be list of ranges") dom = Range(dom[0], dom[1]) - if not isinstance(dom, _container.Range): + if not isinstance(dom, tvm.ir.Range): raise TypeError("dom need to be Range") name = name if name else 'iter' v = var(name) diff --git a/python/tvm/build_module.py b/python/tvm/build_module.py index 1ac13a2c48b2b..768f438844188 100644 --- a/python/tvm/build_module.py +++ b/python/tvm/build_module.py @@ -24,6 +24,7 @@ import tvm.runtime from tvm.runtime import Object, ndarray +from tvm.ir import container from . import api from . import _api_internal from . import tensor @@ -31,7 +32,6 @@ from . import expr from . import ir_pass from . import stmt as _stmt -from . import container from . import codegen from . import target as _target from . import make diff --git a/python/tvm/hybrid/calls.py b/python/tvm/hybrid/calls.py index e873e1974d213..630c10fcf2dd1 100644 --- a/python/tvm/hybrid/calls.py +++ b/python/tvm/hybrid/calls.py @@ -16,12 +16,11 @@ # under the License. """Intrinsics of TVM-Python Hybrid Script for Python compilation time semantic support.""" - +from tvm.ir.container import Array from .. import api as _api from .. import expr as _expr from .. import make as _make from .. import target as _tgt -from ..container import Array from .. import ir_pass from ..stmt import For from .util import _internal_assert diff --git a/python/tvm/hybrid/parser.py b/python/tvm/hybrid/parser.py index cd2433e64a8c6..a0b2dfea60620 100644 --- a/python/tvm/hybrid/parser.py +++ b/python/tvm/hybrid/parser.py @@ -24,6 +24,7 @@ import numbers from enum import Enum +from tvm.ir.container import Array from .util import _internal_assert from . import calls @@ -32,7 +33,6 @@ from ..api import all as _all from ..api import any as _any -from ..container import Array from ..tensor import Tensor, Operation from .. import _api_internal as _tvm_internal from .. import expr as _expr diff --git a/python/tvm/hybrid/util.py b/python/tvm/hybrid/util.py index 0883960fabfd6..8ef200a02b2c7 100644 --- a/python/tvm/hybrid/util.py +++ b/python/tvm/hybrid/util.py @@ -21,13 +21,14 @@ import logging import sys import numpy + +from tvm.ir.container import Array from .. import api as _api from .. import make as _make from .. import expr as _expr from .. import stmt as _stmt from .._ffi.base import numeric_types from ..tensor import Tensor -from ..container import Array #pylint: disable=invalid-name diff --git a/python/tvm/ir/__init__.py b/python/tvm/ir/__init__.py index 88d7dd23eefa4..e3552b5fe0470 100644 --- a/python/tvm/ir/__init__.py +++ b/python/tvm/ir/__init__.py @@ -16,8 +16,8 @@ # under the License. # pylint: disable=unused-import """Common data structures across all IR variants.""" -from .base import SourceName, Span, Node, load_json, save_json -from .expr import BaseExpr, PrimExpr, RelayExpr, GlobalVar, BaseFunc +from .base import SourceName, Span, Node, EnvFunc, load_json, save_json +from .expr import BaseExpr, PrimExpr, RelayExpr, GlobalVar, BaseFunc, Range from .type import Type, TypeKind, TypeVar, GlobalTypeVar, TupleType from .type import TypeConstraint, FuncType, IncompleteType, RelayRefType from .type_relation import TypeCall, TypeRelation @@ -25,5 +25,6 @@ from .adt import Constructor, TypeData from .module import IRModule from .attrs import Attrs +from .container import Array, Map from . import transform diff --git a/python/tvm/ir/base.py b/python/tvm/ir/base.py index f2d5f89d8f658..3515e7528a7ed 100644 --- a/python/tvm/ir/base.py +++ b/python/tvm/ir/base.py @@ -87,6 +87,31 @@ def __init__(self, source, lineno, col_offset): _ffi_api.Span, source, lineno, col_offset) +@tvm._ffi.register_object +class EnvFunc(Object): + """Environment function. + + This is a global function object that can be serialized by its name. + """ + def __call__(self, *args): + return _ffi_api.EnvFuncCall(self, *args) + + @property + def func(self): + return _ffi_api.EnvFuncGetPackedFunc(self) + + @staticmethod + def get(name): + """Get a static env function + + Parameters + ---------- + name : str + The name of the function. + """ + return _ffi_api.EnvFuncGet(name) + + def load_json(json_str): """Load tvm object from json_str. diff --git a/python/tvm/ir/container.py b/python/tvm/ir/container.py index 20bb924aad4ac..11ef107f5514a 100644 --- a/python/tvm/ir/container.py +++ b/python/tvm/ir/container.py @@ -14,41 +14,12 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""Container data structures used across IR variants.""" +"""Additional container data structures used across IR variants.""" import tvm._ffi from tvm.runtime import Object from tvm.runtime.container import getitem_helper from tvm.runtime import _ffi_node_api -from . import _ffi_api - - -@tvm._ffi.register_object -class EnvFunc(Object): - """Environment function. - - This is a global function object that can be serialized by its name. - """ - def __call__(self, *args): - return _ffi_api.EnvFuncCall(self, *args) - - @property - def func(self): - return _ffi_api.EnvFuncGetPackedFunc(self) - - @staticmethod - def get(name): - """Get a static env function - """ - - -@tvm._ffi.register_object -class Range(Object): - """Represent a range in TVM. - - You do not need to create a Range explicitly. - Python lists and tuples will be converted automatically to a Range in API functions. - """ @tvm._ffi.register_object diff --git a/python/tvm/ir/expr.py b/python/tvm/ir/expr.py index c5b146626d8ee..46acd16b80319 100644 --- a/python/tvm/ir/expr.py +++ b/python/tvm/ir/expr.py @@ -17,6 +17,7 @@ """Common expressions data structures in the IR.""" import tvm._ffi + from .base import Node from . import _ffi_api @@ -89,3 +90,12 @@ def __call__(self, *args): arg_types = [type(x) for x in args] raise RuntimeError( "Do not know how to handle GlobalVar.__call__ for types {}".format(arg_types)) + + +@tvm._ffi.register_object +class Range(Node): + """Represent a range in TVM. + + You do not need to create a Range explicitly. + Python lists and tuples will be converted automatically to a Range in API functions. + """ diff --git a/python/tvm/ir_builder.py b/python/tvm/ir_builder.py index c08b3a54f1aca..4cc7f4f8082d0 100644 --- a/python/tvm/ir_builder.py +++ b/python/tvm/ir_builder.py @@ -15,16 +15,15 @@ # specific language governing permissions and limitations # under the License. """Developer API of IR node builder make function.""" +from tvm._ffi.base import string_types from tvm.runtime import ObjectGeneric, DataType - -from ._ffi.base import string_types +from tvm.ir import container as _container from . import api as _api from . import stmt as _stmt from . import expr as _expr from . import make as _make from . import ir_pass as _pass -from . import container as _container from .expr import Call as _Call class WithScope(object): diff --git a/python/tvm/relay/backend/_backend.py b/python/tvm/relay/backend/_backend.py index 270c38e4f523d..c2f1df915509f 100644 --- a/python/tvm/relay/backend/_backend.py +++ b/python/tvm/relay/backend/_backend.py @@ -16,9 +16,9 @@ # under the License. """The interface of expr function exposed from C++.""" import tvm._ffi +from tvm.ir import container as _container from ... import build_module as _build -from ... import container as _container @tvm._ffi.register_func("relay.backend.lower") diff --git a/python/tvm/relay/op/nn/util.py b/python/tvm/relay/op/nn/util.py index ba536ad399360..323ef7f9310e2 100644 --- a/python/tvm/relay/op/nn/util.py +++ b/python/tvm/relay/op/nn/util.py @@ -16,8 +16,8 @@ # under the License. # pylint: disable=invalid-name, unused-variable """NN operator common utilities""" -from __future__ import absolute_import -from .... import container +from tvm.ir import container + def get_pad_tuple2d(padding): """Common code to get the pad option diff --git a/python/tvm/schedule.py b/python/tvm/schedule.py index d779b5979b186..650bf9d1aab18 100644 --- a/python/tvm/schedule.py +++ b/python/tvm/schedule.py @@ -19,11 +19,11 @@ from tvm._ffi.base import string_types from tvm.runtime import Object, convert +from tvm.ir import container as _container from . import _api_internal from . import tensor as _tensor from . import expr as _expr -from . import container as _container @tvm._ffi.register_object diff --git a/tests/python/relay/test_ir_nodes.py b/tests/python/relay/test_ir_nodes.py index 2fb4cb0be442a..29e578b08b118 100644 --- a/tests/python/relay/test_ir_nodes.py +++ b/tests/python/relay/test_ir_nodes.py @@ -109,7 +109,7 @@ def test_type_relation(): args = tvm.convert([tp, tf, tt]) num_inputs = 2 - func = tvm.get_env_func("tvm.relay.type_relation.Broadcast") + func = tvm.ir.EnvFunc.get("tvm.relay.type_relation.Broadcast") attrs = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3,4)) tr = relay.TypeRelation(func, args, num_inputs, attrs) diff --git a/tests/python/relay/test_pass_alpha_equal.py b/tests/python/relay/test_pass_alpha_equal.py index 63e3ab86d7eba..bdc032e8b65ce 100644 --- a/tests/python/relay/test_pass_alpha_equal.py +++ b/tests/python/relay/test_pass_alpha_equal.py @@ -81,8 +81,8 @@ def test_func_type_alpha_equal(): tp3 = relay.TypeVar("v3", relay.TypeKind.ShapeVar) tp4 = relay.TypeVar("v3", relay.TypeKind.ShapeVar) - broadcast = tvm.get_env_func("tvm.relay.type_relation.Broadcast") - identity = tvm.get_env_func("tvm.relay.type_relation.Identity") + broadcast = tvm.ir.EnvFunc.get("tvm.relay.type_relation.Broadcast") + identity = tvm.ir.EnvFunc.get("tvm.relay.type_relation.Identity") tr1 = relay.TypeRelation(broadcast, tvm.convert([tp1, tp3]), 1, None) tr2 = relay.TypeRelation(broadcast, tvm.convert([tp2, tp4]), 1, None) @@ -157,8 +157,8 @@ def test_type_relation_alpha_equal(): # functions are compared only by pointer equality so # we need to be sure to use the same pointers - broadcast = tvm.get_env_func("tvm.relay.type_relation.Broadcast") - identity = tvm.get_env_func("tvm.relay.type_relation.Identity") + broadcast = tvm.ir.EnvFunc.get("tvm.relay.type_relation.Broadcast") + identity = tvm.ir.EnvFunc.get("tvm.relay.type_relation.Identity") attr1 = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3,4)) attr1_same = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3,4)) diff --git a/tests/python/relay/test_pass_check_kind.py b/tests/python/relay/test_pass_check_kind.py index a6655e66afb54..62a92040ff166 100644 --- a/tests/python/relay/test_pass_check_kind.py +++ b/tests/python/relay/test_pass_check_kind.py @@ -150,7 +150,7 @@ def test_invalid_relation_kind(): tp3 = relay.TypeVar('tp3', relay.TypeKind.Constraint) args = tvm.convert([tp1, tp2, tp3]) - func = tvm.get_env_func("tvm.relay.type_relation.Broadcast") + func = tvm.ir.EnvFunc.get("tvm.relay.type_relation.Broadcast") tr = relay.TypeRelation(func, args, 2, None) check_kind(tr) @@ -217,7 +217,7 @@ def test_func_with_invalid_relation(): tp2 = relay.TypeVar('tp2', relay.TypeKind.ShapeVar) tp3 = relay.TypeVar('tp3', relay.TypeKind.Constraint) - func = tvm.get_env_func("tvm.relay.type_relation.Identity") + func = tvm.ir.EnvFunc.get("tvm.relay.type_relation.Identity") tr = relay.TypeRelation(func, tvm.convert([tp2, tp3]), 1, None) tf = relay.FuncType(tvm.convert([tp1]), tp1, tvm.convert([tp1, tp2, tp3]), tvm.convert([tr])) diff --git a/tests/python/relay/test_type_functor.py b/tests/python/relay/test_type_functor.py index d09a8938bb544..b84a49ae3fdaa 100644 --- a/tests/python/relay/test_type_functor.py +++ b/tests/python/relay/test_type_functor.py @@ -64,7 +64,7 @@ def test_tuple_type(): def test_type_relation(): - func = tvm.get_env_func('tvm.relay.type_relation.Broadcast') + func = tvm.ir.EnvFunc.get('tvm.relay.type_relation.Broadcast') attrs = tvm.make.node('attrs.TestAttrs', name='attr', padding=(3,4)) tp = TypeVar('tp') tf = FuncType([], TupleType([]), [], []) diff --git a/tests/python/relay/test_type_solver.py b/tests/python/relay/test_type_solver.py index 11274181af734..118066e7cf529 100644 --- a/tests/python/relay/test_type_solver.py +++ b/tests/python/relay/test_type_solver.py @@ -20,7 +20,7 @@ def make_rel(name, args, num_inputs=None, attrs=None): - func = tvm.get_env_func("tvm.relay.type_relation." + name) + func = tvm.ir.EnvFunc.get("tvm.relay.type_relation." + name) if num_inputs is None: num_inputs = len(args) - 1 return relay.ty.TypeRelation(func, args, num_inputs, attrs) diff --git a/tests/python/unittest/test_lang_reflection.py b/tests/python/unittest/test_lang_reflection.py index a0b844ef864a6..b971e386cfc74 100644 --- a/tests/python/unittest/test_lang_reflection.py +++ b/tests/python/unittest/test_lang_reflection.py @@ -96,7 +96,7 @@ def test(x): return x + 1 f = tvm.get_global_func("test.env_func") - x = tvm.get_env_func("test.env_func") + x = tvm.ir.EnvFunc.get("test.env_func") assert x.name == "test.env_func" json_str = tvm.ir.save_json([x]) y = tvm.ir.load_json(json_str)[0] @@ -110,7 +110,7 @@ def test(x): assert x.padding[1].value == 4 assert x.axis == 10 x = tvm.ir.load_json(tvm.ir.save_json(x)) - assert isinstance(x.func, tvm.container.EnvFunc) + assert isinstance(x.func, tvm.ir.EnvFunc) assert x.func(10) == 11