Skip to content

Commit

Permalink
Finish migrate container
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen committed Feb 12, 2020
1 parent 7746404 commit c6e93bc
Show file tree
Hide file tree
Showing 20 changed files with 66 additions and 81 deletions.
2 changes: 1 addition & 1 deletion python/tvm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
# tvm.ir
from .ir import IRModule
from .ir import transform
from .ir import container
from . import ir

# others
Expand All @@ -46,7 +47,6 @@
from . import make
from . import ir_pass
from . import codegen
from . import container
from . import schedule

from . import ir_builder
Expand Down
29 changes: 4 additions & 25 deletions python/tvm/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,11 @@
from numbers import Integral as _Integral

import tvm._ffi
import tvm.ir

from tvm.runtime import convert, const, DataType
from tvm.ir import container as _container

from ._ffi.base import string_types, TVMError
from ._ffi.registry import register_func, get_global_func, extract_ext_funcs

Expand All @@ -29,7 +32,6 @@
from . import expr as _expr
from . import tensor as _tensor
from . import schedule as _schedule
from . import container as _container
from . import tag as _tag

int8 = "int8"
Expand Down Expand Up @@ -69,29 +71,6 @@ def max_value(dtype):
"""
return _api_internal._max_value(dtype)


def get_env_func(name):
"""Get an EnvFunc by a global name.
Parameters
----------
name: str
The name of the global function.
Returns
-------
env_func : EnvFunc
The result env function.
Note
----
EnvFunc is a Object wrapper around
global function that can be serialized via its name.
This can be used to serialize function field in the language.
"""
return _api_internal._EnvFuncGet(name)


def var(name="tindex", dtype=int32):
"""Create a new variable with specified name and dtype
Expand Down Expand Up @@ -649,7 +628,7 @@ def _IterVar(dom, name, iter_type, thread_tag=''):
raise TypeError("need to be list of ranges")
dom = Range(dom[0], dom[1])

if not isinstance(dom, _container.Range):
if not isinstance(dom, tvm.ir.Range):
raise TypeError("dom need to be Range")
name = name if name else 'iter'
v = var(name)
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/build_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,14 @@
import tvm.runtime

from tvm.runtime import Object, ndarray
from tvm.ir import container
from . import api
from . import _api_internal
from . import tensor
from . import schedule
from . import expr
from . import ir_pass
from . import stmt as _stmt
from . import container
from . import codegen
from . import target as _target
from . import make
Expand Down
3 changes: 1 addition & 2 deletions python/tvm/hybrid/calls.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,11 @@
# under the License.
"""Intrinsics of TVM-Python Hybrid Script for Python compilation time
semantic support."""

from tvm.ir.container import Array
from .. import api as _api
from .. import expr as _expr
from .. import make as _make
from .. import target as _tgt
from ..container import Array
from .. import ir_pass
from ..stmt import For
from .util import _internal_assert
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/hybrid/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import numbers

from enum import Enum
from tvm.ir.container import Array

from .util import _internal_assert
from . import calls
Expand All @@ -32,7 +33,6 @@
from ..api import all as _all
from ..api import any as _any

from ..container import Array
from ..tensor import Tensor, Operation
from .. import _api_internal as _tvm_internal
from .. import expr as _expr
Expand Down
3 changes: 2 additions & 1 deletion python/tvm/hybrid/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,14 @@
import logging
import sys
import numpy

from tvm.ir.container import Array
from .. import api as _api
from .. import make as _make
from .. import expr as _expr
from .. import stmt as _stmt
from .._ffi.base import numeric_types
from ..tensor import Tensor
from ..container import Array


#pylint: disable=invalid-name
Expand Down
5 changes: 3 additions & 2 deletions python/tvm/ir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,15 @@
# under the License.
# pylint: disable=unused-import
"""Common data structures across all IR variants."""
from .base import SourceName, Span, Node, load_json, save_json
from .expr import BaseExpr, PrimExpr, RelayExpr, GlobalVar, BaseFunc
from .base import SourceName, Span, Node, EnvFunc, load_json, save_json
from .expr import BaseExpr, PrimExpr, RelayExpr, GlobalVar, BaseFunc, Range
from .type import Type, TypeKind, TypeVar, GlobalTypeVar, TupleType
from .type import TypeConstraint, FuncType, IncompleteType, RelayRefType
from .type_relation import TypeCall, TypeRelation
from .tensor_type import TensorType
from .adt import Constructor, TypeData
from .module import IRModule
from .attrs import Attrs
from .container import Array, Map

from . import transform
25 changes: 25 additions & 0 deletions python/tvm/ir/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,31 @@ def __init__(self, source, lineno, col_offset):
_ffi_api.Span, source, lineno, col_offset)


@tvm._ffi.register_object
class EnvFunc(Object):
"""Environment function.
This is a global function object that can be serialized by its name.
"""
def __call__(self, *args):
return _ffi_api.EnvFuncCall(self, *args)

@property
def func(self):
return _ffi_api.EnvFuncGetPackedFunc(self)

@staticmethod
def get(name):
"""Get a static env function
Parameters
----------
name : str
The name of the function.
"""
return _ffi_api.EnvFuncGet(name)


def load_json(json_str):
"""Load tvm object from json_str.
Expand Down
31 changes: 1 addition & 30 deletions python/tvm/ir/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,41 +14,12 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Container data structures used across IR variants."""
"""Additional container data structures used across IR variants."""
import tvm._ffi

from tvm.runtime import Object
from tvm.runtime.container import getitem_helper
from tvm.runtime import _ffi_node_api
from . import _ffi_api


@tvm._ffi.register_object
class EnvFunc(Object):
"""Environment function.
This is a global function object that can be serialized by its name.
"""
def __call__(self, *args):
return _ffi_api.EnvFuncCall(self, *args)

@property
def func(self):
return _ffi_api.EnvFuncGetPackedFunc(self)

@staticmethod
def get(name):
"""Get a static env function
"""


@tvm._ffi.register_object
class Range(Object):
"""Represent a range in TVM.
You do not need to create a Range explicitly.
Python lists and tuples will be converted automatically to a Range in API functions.
"""


@tvm._ffi.register_object
Expand Down
10 changes: 10 additions & 0 deletions python/tvm/ir/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
"""Common expressions data structures in the IR."""
import tvm._ffi


from .base import Node
from . import _ffi_api

Expand Down Expand Up @@ -89,3 +90,12 @@ def __call__(self, *args):
arg_types = [type(x) for x in args]
raise RuntimeError(
"Do not know how to handle GlobalVar.__call__ for types {}".format(arg_types))


@tvm._ffi.register_object
class Range(Node):
"""Represent a range in TVM.
You do not need to create a Range explicitly.
Python lists and tuples will be converted automatically to a Range in API functions.
"""
5 changes: 2 additions & 3 deletions python/tvm/ir_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,15 @@
# specific language governing permissions and limitations
# under the License.
"""Developer API of IR node builder make function."""
from tvm._ffi.base import string_types
from tvm.runtime import ObjectGeneric, DataType

from ._ffi.base import string_types
from tvm.ir import container as _container

from . import api as _api
from . import stmt as _stmt
from . import expr as _expr
from . import make as _make
from . import ir_pass as _pass
from . import container as _container
from .expr import Call as _Call

class WithScope(object):
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/backend/_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@
# under the License.
"""The interface of expr function exposed from C++."""
import tvm._ffi
from tvm.ir import container as _container

from ... import build_module as _build
from ... import container as _container


@tvm._ffi.register_func("relay.backend.lower")
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/relay/op/nn/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
# under the License.
# pylint: disable=invalid-name, unused-variable
"""NN operator common utilities"""
from __future__ import absolute_import
from .... import container
from tvm.ir import container


def get_pad_tuple2d(padding):
"""Common code to get the pad option
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@

from tvm._ffi.base import string_types
from tvm.runtime import Object, convert
from tvm.ir import container as _container

from . import _api_internal
from . import tensor as _tensor
from . import expr as _expr
from . import container as _container


@tvm._ffi.register_object
Expand Down
2 changes: 1 addition & 1 deletion tests/python/relay/test_ir_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def test_type_relation():
args = tvm.convert([tp, tf, tt])

num_inputs = 2
func = tvm.get_env_func("tvm.relay.type_relation.Broadcast")
func = tvm.ir.EnvFunc.get("tvm.relay.type_relation.Broadcast")
attrs = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3,4))

tr = relay.TypeRelation(func, args, num_inputs, attrs)
Expand Down
8 changes: 4 additions & 4 deletions tests/python/relay/test_pass_alpha_equal.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,8 @@ def test_func_type_alpha_equal():
tp3 = relay.TypeVar("v3", relay.TypeKind.ShapeVar)
tp4 = relay.TypeVar("v3", relay.TypeKind.ShapeVar)

broadcast = tvm.get_env_func("tvm.relay.type_relation.Broadcast")
identity = tvm.get_env_func("tvm.relay.type_relation.Identity")
broadcast = tvm.ir.EnvFunc.get("tvm.relay.type_relation.Broadcast")
identity = tvm.ir.EnvFunc.get("tvm.relay.type_relation.Identity")

tr1 = relay.TypeRelation(broadcast, tvm.convert([tp1, tp3]), 1, None)
tr2 = relay.TypeRelation(broadcast, tvm.convert([tp2, tp4]), 1, None)
Expand Down Expand Up @@ -157,8 +157,8 @@ def test_type_relation_alpha_equal():

# functions are compared only by pointer equality so
# we need to be sure to use the same pointers
broadcast = tvm.get_env_func("tvm.relay.type_relation.Broadcast")
identity = tvm.get_env_func("tvm.relay.type_relation.Identity")
broadcast = tvm.ir.EnvFunc.get("tvm.relay.type_relation.Broadcast")
identity = tvm.ir.EnvFunc.get("tvm.relay.type_relation.Identity")

attr1 = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3,4))
attr1_same = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3,4))
Expand Down
4 changes: 2 additions & 2 deletions tests/python/relay/test_pass_check_kind.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def test_invalid_relation_kind():
tp3 = relay.TypeVar('tp3', relay.TypeKind.Constraint)
args = tvm.convert([tp1, tp2, tp3])

func = tvm.get_env_func("tvm.relay.type_relation.Broadcast")
func = tvm.ir.EnvFunc.get("tvm.relay.type_relation.Broadcast")
tr = relay.TypeRelation(func, args, 2, None)
check_kind(tr)

Expand Down Expand Up @@ -217,7 +217,7 @@ def test_func_with_invalid_relation():
tp2 = relay.TypeVar('tp2', relay.TypeKind.ShapeVar)
tp3 = relay.TypeVar('tp3', relay.TypeKind.Constraint)

func = tvm.get_env_func("tvm.relay.type_relation.Identity")
func = tvm.ir.EnvFunc.get("tvm.relay.type_relation.Identity")
tr = relay.TypeRelation(func, tvm.convert([tp2, tp3]), 1, None)

tf = relay.FuncType(tvm.convert([tp1]), tp1, tvm.convert([tp1, tp2, tp3]), tvm.convert([tr]))
Expand Down
2 changes: 1 addition & 1 deletion tests/python/relay/test_type_functor.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def test_tuple_type():


def test_type_relation():
func = tvm.get_env_func('tvm.relay.type_relation.Broadcast')
func = tvm.ir.EnvFunc.get('tvm.relay.type_relation.Broadcast')
attrs = tvm.make.node('attrs.TestAttrs', name='attr', padding=(3,4))
tp = TypeVar('tp')
tf = FuncType([], TupleType([]), [], [])
Expand Down
2 changes: 1 addition & 1 deletion tests/python/relay/test_type_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@


def make_rel(name, args, num_inputs=None, attrs=None):
func = tvm.get_env_func("tvm.relay.type_relation." + name)
func = tvm.ir.EnvFunc.get("tvm.relay.type_relation." + name)
if num_inputs is None:
num_inputs = len(args) - 1
return relay.ty.TypeRelation(func, args, num_inputs, attrs)
Expand Down
4 changes: 2 additions & 2 deletions tests/python/unittest/test_lang_reflection.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def test(x):
return x + 1

f = tvm.get_global_func("test.env_func")
x = tvm.get_env_func("test.env_func")
x = tvm.ir.EnvFunc.get("test.env_func")
assert x.name == "test.env_func"
json_str = tvm.ir.save_json([x])
y = tvm.ir.load_json(json_str)[0]
Expand All @@ -110,7 +110,7 @@ def test(x):
assert x.padding[1].value == 4
assert x.axis == 10
x = tvm.ir.load_json(tvm.ir.save_json(x))
assert isinstance(x.func, tvm.container.EnvFunc)
assert isinstance(x.func, tvm.ir.EnvFunc)
assert x.func(10) == 11


Expand Down

0 comments on commit c6e93bc

Please sign in to comment.