Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions python/tvm/relax/expr_functor.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@

import tvm
from tvm.ir import Op
from tvm.meta_schedule.utils import derived_object
from tvm.runtime import Object
from tvm.runtime.support import derived_object

from ..ir.module import IRModule
from . import _ffi_api
Expand All @@ -31,14 +31,14 @@
BindingBlock,
Call,
Constant,
Id,
DataflowBlock,
DataflowVar,
DataTypeImm,
Expr,
ExternFunc,
Function,
GlobalVar,
Id,
If,
MatchCast,
PrimValue,
Expand Down
137 changes: 137 additions & 0 deletions python/tvm/runtime/support.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
"""Runtime support infra of TVM."""

import re
from typing import TypeVar

import tvm.ffi

Expand Down Expand Up @@ -67,3 +68,139 @@ def _regex_match(regex_pattern: str, match_against: str) -> bool:
"""
match = re.match(regex_pattern, match_against)
return match is not None


T = TypeVar("T")


def derived_object(cls: type[T]) -> type[T]:
"""A decorator to register derived subclasses for TVM objects.

Parameters
----------
cls : type
The derived class to be registered.

Returns
-------
cls : type
The decorated TVM object.

Example
-------
.. code-block:: python

@register_object("meta_schedule.PyRunner")
class _PyRunner(meta_schedule.Runner):
def __init__(self, f_run: Callable = None):
self.__init_handle_by_constructor__(_ffi_api.RunnerPyRunner, f_run)

class PyRunner:
_tvm_metadata = {
"cls": _PyRunner,
"methods": ["run"]
}
def run(self, runner_inputs):
raise NotImplementedError

@derived_object
class LocalRunner(PyRunner):
def run(self, runner_inputs):
...
"""

import functools # pylint: disable=import-outside-toplevel
import weakref # pylint: disable=import-outside-toplevel

def _extract(inst: type, name: str):
"""Extract function from intrinsic class."""

def method(*args, **kwargs):
return getattr(inst, name)(*args, **kwargs)

for inherit_cls, base_cls in zip(cls.__mro__, cls.__mro__[1:]):
# extract functions that differ from the base class
if not hasattr(base_cls, name):
continue
if getattr(base_cls, name) is getattr(inherit_cls, name) and name != "__str__":
continue
return method

# for task scheduler return None means calling default function
# otherwise it will trigger a TVMError of method not implemented
# on the c++ side when you call the method, __str__ not required
return None

assert isinstance(cls.__base__, type)
if hasattr(cls, "_type") and cls._type == "TVMDerivedObject": # type: ignore
raise TypeError(
(
f"Inheritance from a decorated object `{cls.__name__}` is not allowed. "
f"Please inherit from `{cls.__name__}._cls`."
)
)
assert hasattr(
cls, "_tvm_metadata"
), "Please use the user-facing method overriding class, i.e., PyRunner."

base = cls.__base__
metadata = getattr(base, "_tvm_metadata")
fields = metadata.get("fields", [])
methods = metadata.get("methods", [])

class TVMDerivedObject(metadata["cls"]): # type: ignore
"""The derived object to avoid cyclic dependency."""

_cls = cls
_type = "TVMDerivedObject"

def __init__(self, *args, **kwargs):
"""Constructor."""
self._inst = cls(*args, **kwargs)

super().__init__(
# the constructor's parameters, builder, runner, etc.
*[getattr(self._inst, name) for name in fields],
# the function methods, init_with_tune_context, build, run, etc.
*[_extract(self._inst, name) for name in methods],
)

# for task scheduler hybrid funcs in c++ & python side
# using weakref to avoid cyclic dependency
self._inst._outer = weakref.ref(self)

def __getattr__(self, name):
import inspect # pylint: disable=import-outside-toplevel

try:
# fall back to instance attribute if there is not any
# return self._inst.__getattribute__(name)
result = self._inst.__getattribute__(name)
except AttributeError:
result = super(TVMDerivedObject, self).__getattr__(name)

if inspect.ismethod(result):

def method(*args, **kwargs):
return result(*args, **kwargs)

# set __own__ to aviod implicit deconstruction
setattr(method, "__own__", self)
return method

return result

def __setattr__(self, name, value):
if name not in ["_inst", "key", "handle"]:
self._inst.__setattr__(name, value)
else:
super(TVMDerivedObject, self).__setattr__(name, value)

functools.update_wrapper(TVMDerivedObject.__init__, cls.__init__) # type: ignore
TVMDerivedObject.__name__ = cls.__name__
TVMDerivedObject.__doc__ = cls.__doc__
TVMDerivedObject.__module__ = cls.__module__
for key, value in cls.__dict__.items():
if isinstance(value, (classmethod, staticmethod)):
setattr(TVMDerivedObject, key, value)
return TVMDerivedObject
1 change: 1 addition & 0 deletions python/tvm/tir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,3 +108,4 @@
from . import stmt_functor
from .build import build
from .pipeline import get_tir_pipeline, get_default_tir_pipeline
from .functor import PyStmtExprVisitor, PyStmtExprMutator
Loading
Loading