diff --git a/docs/reference/api/python/relax/op.rst b/docs/reference/api/python/relax/op.rst index 21f638442a84..922af768f50f 100644 --- a/docs/reference/api/python/relax/op.rst +++ b/docs/reference/api/python/relax/op.rst @@ -70,3 +70,4 @@ tvm.relax.op.op_attrs ********************* .. automodule:: tvm.relax.op.op_attrs :members: + :exclude-members: Attrs diff --git a/docs/reference/api/python/tir/transform.rst b/docs/reference/api/python/tir/transform.rst index 8ce641b6d3f6..29f1bcbbf036 100644 --- a/docs/reference/api/python/tir/transform.rst +++ b/docs/reference/api/python/tir/transform.rst @@ -20,4 +20,5 @@ tvm.tir.transform ----------------- .. automodule:: tvm.tir.transform :members: + :exclude-members: Attrs :imported-members: diff --git a/ffi/src/ffi/extra/serialization.cc b/ffi/src/ffi/extra/serialization.cc index 8d9df03361c2..ea9a96b696ec 100644 --- a/ffi/src/ffi/extra/serialization.cc +++ b/ffi/src/ffi/extra/serialization.cc @@ -408,9 +408,20 @@ class ObjectGraphDeserializer { Any FromJSONGraph(const json::Value& value) { return ObjectGraphDeserializer::Deserialize(value); } +// string version of the api +Any FromJSONGraphString(const String& value) { return FromJSONGraph(json::Parse(value)); } + +String ToJSONGraphString(const Any& value, const Any& metadata) { + return json::Stringify(ToJSONGraph(value, metadata)); +} + TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("ffi.ToJSONGraph", ToJSONGraph).def("ffi.FromJSONGraph", FromJSONGraph); + refl::GlobalDef() + .def("ffi.ToJSONGraph", ToJSONGraph) + .def("ffi.ToJSONGraphString", ToJSONGraphString) + .def("ffi.FromJSONGraph", FromJSONGraph) + .def("ffi.FromJSONGraphString", FromJSONGraphString); refl::EnsureTypeAttrColumn("__data_to_json__"); refl::EnsureTypeAttrColumn("__data_from_json__"); }); diff --git a/include/tvm/relax/attrs/op.h b/include/tvm/relax/attrs/op.h index cce78e9fd615..337f8dc4cbc2 100644 --- a/include/tvm/relax/attrs/op.h +++ b/include/tvm/relax/attrs/op.h @@ -51,16 +51,19 @@ struct CallTIRWithGradAttrs : public AttrsNodeReflAdapter /*! \brief Attributes used in call_tir_inplace */ struct CallTIRInplaceAttrs : public AttrsNodeReflAdapter { + /*! + * \brief Indices that describe which input corresponds to which output. + * + * If the `i`th member has the value `k` >= 0, then that means that input `k` should be used to + * store the `i`th output. If an element has the value -1, that means a new tensor should be + * allocated for that output. + */ Array inplace_indices; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; - refl::ObjectDef().def_ro( - "inplace_indices", &CallTIRInplaceAttrs::inplace_indices, - "Indices that describe which input corresponds to which output. If the `i`th member " - "has the value `k` >= 0, then that means that input `k` should be used to store the " - "`i`th output. If an element has the value -1, that means a new tensor should be " - "allocated for that output."); + refl::ObjectDef().def_ro("inplace_indices", + &CallTIRInplaceAttrs::inplace_indices); } static constexpr const char* _type_key = "relax.attrs.CallTIRInplaceAttrs"; @@ -69,16 +72,19 @@ struct CallTIRInplaceAttrs : public AttrsNodeReflAdapter { /*! \brief Attributes used in call_inplace_packed */ struct CallInplacePackedAttrs : public AttrsNodeReflAdapter { + /*! + * \brief Indices that describe which input corresponds to which output. + * + * If the `i`th member has the value `k` >= 0, then that means that input `k` should be used to + * store the `i`th output. If an element has the value -1, that means the output will be newly + * allocated. + */ Array inplace_indices; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; - refl::ObjectDef().def_ro( - "inplace_indices", &CallInplacePackedAttrs::inplace_indices, - "Indices that describe which input corresponds to which output. If the `i`th member " - "has the value `k` >= 0, then that means that input `k` should be used to store the " - "`i`th output. If an element has the value -1, that means the output will be newly " - "allocated."); + refl::ObjectDef().def_ro("inplace_indices", + &CallInplacePackedAttrs::inplace_indices); } static constexpr const char* _type_key = "relax.attrs.CallInplacePackedAttrs"; diff --git a/include/tvm/script/printer/doc.h b/include/tvm/script/printer/doc.h index b19bcab4c3ef..de3fb0bbad2c 100644 --- a/include/tvm/script/printer/doc.h +++ b/include/tvm/script/printer/doc.h @@ -65,10 +65,11 @@ class DocNode : public Object { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; - refl::ObjectDef().def_ro("source_paths", &DocNode::source_paths); + refl::ObjectDef().def_rw("source_paths", &DocNode::source_paths); } static constexpr const char* _type_key = "script.printer.Doc"; + static constexpr bool _type_mutable = true; TVM_DECLARE_BASE_OBJECT_INFO(DocNode, Object); @@ -174,7 +175,7 @@ class StmtDocNode : public DocNode { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; - refl::ObjectDef().def_ro("comment", &StmtDocNode::comment); + refl::ObjectDef().def_rw("comment", &StmtDocNode::comment); } static constexpr const char* _type_key = "script.printer.StmtDoc"; diff --git a/python/tvm/arith/iter_affine_map.py b/python/tvm/arith/iter_affine_map.py index dbb4087f325f..328bb052b87f 100644 --- a/python/tvm/arith/iter_affine_map.py +++ b/python/tvm/arith/iter_affine_map.py @@ -22,6 +22,7 @@ from . import _ffi_api +@tvm.ffi.register_object("arith.IterMapExpr") class IterMapExpr(PrimExpr): """Base class of all IterMap expressions.""" @@ -89,6 +90,11 @@ def __init__(self, args, base): self.__init_handle_by_constructor__(_ffi_api.IterSumExpr, args, base) +@tvm.ffi.register_object("arith.IterMapResult") +class IterMapResult(Object): + """Result of iter map detection.""" + + class IterMapLevel(IntEnum): """Possible kinds of iter mapping check level.""" diff --git a/python/tvm/contrib/msc/core/ir/graph.py b/python/tvm/contrib/msc/core/ir/graph.py index 9aa5bde93380..7bd88df5f6f4 100644 --- a/python/tvm/contrib/msc/core/ir/graph.py +++ b/python/tvm/contrib/msc/core/ir/graph.py @@ -194,6 +194,7 @@ def ndim(self) -> int: return len(self.shape) +@tvm.ffi.register_object("msc.core.BaseJoint") class BaseJoint(Object): """Base class of all MSC Nodes.""" @@ -561,6 +562,7 @@ def has_attr(self, key: str) -> bool: return bool(_ffi_api.WeightJointHasAttr(self, key)) +@tvm.ffi.register_object("msc.core.BaseGraph") class BaseGraph(Object): """Base class of all MSC Graphs.""" @@ -955,7 +957,7 @@ def visualize(self, path: Optional[str] = None) -> str: @tvm.ffi.register_object("msc.core.WeightGraph") -class WeightGraph(Object): +class WeightGraph(BaseGraph): """The WeightGraph Parameters diff --git a/python/tvm/ffi/__init__.py b/python/tvm/ffi/__init__.py index b507064e34d9..43a20e751c29 100644 --- a/python/tvm/ffi/__init__.py +++ b/python/tvm/ffi/__init__.py @@ -30,6 +30,7 @@ from .ndarray import cpu, cuda, rocm, opencl, metal, vpi, vulkan, ext_dev, hexagon, webgpu from .ndarray import from_dlpack, NDArray, Shape from .container import Array, Map +from . import serialization from . import testing diff --git a/python/tvm/ffi/cython/function.pxi b/python/tvm/ffi/cython/function.pxi index cbff3fecf135..8c9df19642b0 100644 --- a/python/tvm/ffi/cython/function.pxi +++ b/python/tvm/ffi/cython/function.pxi @@ -426,3 +426,5 @@ def _convert_to_ffi_func(object pyfunc): _STR_CONSTRUCTOR = _get_global_func("ffi.String", False) _BYTES_CONSTRUCTOR = _get_global_func("ffi.Bytes", False) +_OBJECT_FROM_JSON_GRAPH_STR = _get_global_func("ffi.FromJSONGraphString", True) +_OBJECT_TO_JSON_GRAPH_STR = _get_global_func("ffi.ToJSONGraphString", True) diff --git a/python/tvm/ffi/cython/object.pxi b/python/tvm/ffi/cython/object.pxi index 4efedf35d8f4..7df5f7a19aff 100644 --- a/python/tvm/ffi/cython/object.pxi +++ b/python/tvm/ffi/cython/object.pxi @@ -14,10 +14,12 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import warnings _CLASS_OBJECT = None _FUNC_CONVERT_TO_OBJECT = None + def _set_class_object(cls): global _CLASS_OBJECT _CLASS_OBJECT = cls @@ -32,31 +34,15 @@ def __object_repr__(obj): return type(obj).__name__ + "(" + obj.__ctypes_handle__().value + ")" -def __object_save_json__(obj): - """Object repr function that can be overridden by assigning to it""" - raise NotImplementedError("JSON serialization depends on downstream init") - - -def __object_load_json__(json_str): - """Object repr function that can be overridden by assigning to it""" - raise NotImplementedError("JSON serialization depends on downstream init") - - -def __object_dir__(obj): - """Object dir function that can be overridden by assigning to it""" - return [] - - -def __object_getattr__(obj, name): - """Object getattr function that can be overridden by assigning to it""" - raise AttributeError() - - def _new_object(cls): """Helper function for pickle""" return cls.__new__(cls) +_OBJECT_FROM_JSON_GRAPH_STR = None +_OBJECT_TO_JSON_GRAPH_STR = None + + class ObjectGeneric: """Base class for all classes that can be converted to object.""" @@ -107,34 +93,24 @@ cdef class Object: return (_new_object, (cls,), self.__getstate__()) def __getstate__(self): + if _OBJECT_TO_JSON_GRAPH_STR is None: + raise RuntimeError("ffi.ToJSONGraphString is not registered, make sure build project with extra API") if not self.__chandle__() == 0: # need to explicit convert to str in case String # returned and triggered another infinite recursion in get state - return {"handle": str(__object_save_json__(self))} + return {"handle": str(_OBJECT_TO_JSON_GRAPH_STR(self, None))} return {"handle": None} def __setstate__(self, state): # pylint: disable=assigning-non-slot, assignment-from-no-return + if _OBJECT_FROM_JSON_GRAPH_STR is None: + raise RuntimeError("ffi.FromJSONGraphString is not registered, make sure build project with extra API") handle = state["handle"] if handle is not None: - self.__init_handle_by_constructor__(__object_load_json__, handle) + self.__init_handle_by_constructor__(_OBJECT_FROM_JSON_GRAPH_STR, handle) else: self.chandle = NULL - def __getattr__(self, name): - if self.chandle == NULL: - raise AttributeError(f"{type(self)} has no attribute {name}") - try: - return __object_getattr__(self, name) - except AttributeError: - raise AttributeError(f"{type(self)} has no attribute {name}") - - def __dir__(self): - # exception safety handling for chandle=None - if self.chandle == NULL: - return [] - return __object_dir__(self) - def __repr__(self): # exception safety handling for chandle=None if self.chandle == NULL: @@ -147,9 +123,6 @@ cdef class Object: def __ne__(self, other): return not self.__eq__(other) - def __init_handle_by_load_json__(self, json_str): - raise NotImplementedError("JSON serialization depends on downstream init") - def __init_handle_by_constructor__(self, fconstructor, *args): """Initialize the handle by calling constructor function. @@ -269,6 +242,15 @@ def _object_type_key_to_index(str type_key): return tidx return None +cdef inline str _type_index_to_key(int32_t tindex): + """get the type key of object class""" + cdef const TVMFFITypeInfo* info = TVMFFIGetTypeInfo(tindex) + cdef const TVMFFIByteArray* type_key + if info == NULL: + return "" + type_key = &(info.type_key) + return py_str(PyBytes_FromStringAndSize(type_key.data, type_key.size)) + cdef inline object make_ret_object(TVMFFIAny result): global OBJECT_TYPE @@ -284,10 +266,14 @@ cdef inline object make_ret_object(TVMFFIAny result): (obj).chandle = result.v_obj return cls.__from_tvm_ffi_object__(cls, obj) obj = cls.__new__(cls) - else: - obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT) - else: - obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT) + (obj).chandle = result.v_obj + return obj + + # object is not found in registered entry + # in this case we need to report an warning + type_key = _type_index_to_key(tindex) + warnings.warn(f"Returning type `{type_key}` which is not registered via register_object, fallback to Object") + obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT) (obj).chandle = result.v_obj return obj diff --git a/python/tvm/ffi/serialization.py b/python/tvm/ffi/serialization.py new file mode 100644 index 000000000000..25d9bcefb828 --- /dev/null +++ b/python/tvm/ffi/serialization.py @@ -0,0 +1,67 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Serialization related utilities to enable some object can be pickled""" + +from typing import Optional, Any +from . import _ffi_api + + +def to_json_graph_str(obj: Any, metadata: Optional[dict] = None): + """ + Dump an object to a JSON graph string. + + The JSON graph string is a string representation of of the object + graph includes the reference information of same objects, which can + be used for serialization and debugging. + + Parameters + ---------- + obj : Any + The object to save. + + metadata : Optional[dict], optional + Extra metadata to save into the json graph string. + + Returns + ------- + json_str : str + The JSON graph string. + """ + return _ffi_api.ToJSONGraphString(obj, metadata) + + +def from_json_graph_str(json_str: str): + """ + Load an object from a JSON graph string. + + The JSON graph string is a string representation of of the object + graph that also includes the reference information. + + Parameters + ---------- + json_str : str + The JSON graph string to load. + + Returns + ------- + obj : Any + The loaded object. + """ + return _ffi_api.FromJSONGraphString(json_str) + + +__all__ = ["from_json_graph_str", "to_json_graph_str"] diff --git a/python/tvm/ir/attrs.py b/python/tvm/ir/attrs.py index e7de1a9f909b..cab982f4e783 100644 --- a/python/tvm/ir/attrs.py +++ b/python/tvm/ir/attrs.py @@ -41,7 +41,7 @@ def get_int_tuple(self, key): ------- value: Tuple of int """ - return tuple(x if isinstance(x, int) else x.value for x in self.__getattr__(key)) + return tuple(x if isinstance(x, int) else x.value for x in getattr(self, key)) def get_int(self, key): """Get a python int value of a key @@ -54,7 +54,7 @@ def get_int(self, key): ------- value: int """ - return self.__getattr__(key) + return getattr(self, key) def get_str(self, key): """Get a python int value of a key @@ -67,10 +67,10 @@ def get_str(self, key): ------- value: int """ - return self.__getattr__(key) + return getattr(self, key) def __getitem__(self, item): - return self.__getattr__(item) + return getattr(self, item) @tvm.ffi.register_object("ir.DictAttrs") @@ -101,6 +101,12 @@ def get(self, key, default=None): def __contains__(self, k): return self._dict().__contains__(k) + def __getattr__(self, name): + try: + return self._dict().__getitem__(name) + except KeyError: + raise AttributeError(f"DictAttrs has no attribute {name}") + def items(self): """Get items from the map.""" return self._dict().items() diff --git a/python/tvm/relax/dpl/pattern.py b/python/tvm/relax/dpl/pattern.py index 633c2c6790da..eca885e03acb 100644 --- a/python/tvm/relax/dpl/pattern.py +++ b/python/tvm/relax/dpl/pattern.py @@ -326,7 +326,7 @@ def __init__(self, name_hint: str = ""): @register_df_node -class DataflowVarPattern(DFPattern): +class DataflowVarPattern(VarPattern): """A pattern for DataflowVar. Parameters diff --git a/python/tvm/relax/expr.py b/python/tvm/relax/expr.py index 9ddaf52e722c..ee9caf3a835b 100644 --- a/python/tvm/relax/expr.py +++ b/python/tvm/relax/expr.py @@ -1177,6 +1177,11 @@ def const( return Constant(value) +@tvm.ffi.register_object("relax.TEPlaceholderOp") +class TEPlaceholderOp(tvm.te.tensor.Operation): + """The placeholder op that represents a relax expression.""" + + def te_tensor( value: Expr, tir_var_map: Dict[tvm.tir.Var, tvm.tir.PrimExpr], name: str = "rxplaceholder" ): diff --git a/python/tvm/relax/op/_op_gradient.py b/python/tvm/relax/op/_op_gradient.py index 41eaa5de5008..fd80f1e31333 100644 --- a/python/tvm/relax/op/_op_gradient.py +++ b/python/tvm/relax/op/_op_gradient.py @@ -829,8 +829,8 @@ def cumsum_grad( The "reversed" cumsum along the same axis. Implemented by some tricks now. """ - axis = orig_call.attrs["axis"] - dtype = orig_call.attrs["dtype"] + axis = orig_call.attrs.axis + dtype = orig_call.attrs.dtype x_shape = _get_shape(orig_call.args[0]) if axis is not None: diff --git a/python/tvm/relax/op/manipulate.py b/python/tvm/relax/op/manipulate.py index 864eb3fec709..bb134f114855 100644 --- a/python/tvm/relax/op/manipulate.py +++ b/python/tvm/relax/op/manipulate.py @@ -624,6 +624,7 @@ def index_put( Examples -------- .. code-block:: python + # inputs data = torch.zeros(3, 3) indices = (torch.tensor([0, 2]), torch.tensor([1, 1])) diff --git a/python/tvm/relax/op/op_attrs.py b/python/tvm/relax/op/op_attrs.py index 3e0f87c48751..9c15cdd96613 100644 --- a/python/tvm/relax/op/op_attrs.py +++ b/python/tvm/relax/op/op_attrs.py @@ -202,3 +202,158 @@ class FlipAttrs(Attrs): @tvm.ffi.register_object("relax.attrs.PadAttrs") class PadAttrs(Attrs): """Attributes used in pad operator""" + + +@tvm.ffi.register_object("relax.attrs.MultinomialFromUniformAttrs") +class MultinomialFromUniformAttrs(Attrs): + """Attributes for multinomial_from_uniform operator""" + + +@tvm.ffi.register_object("relax.attrs.CallInplacePackedAttrs") +class CallInplacePackedAttrs(Attrs): + """Attributes used in call_inplace_packed operator""" + + +@tvm.ffi.register_object("relax.attrs.CallTIRInplaceAttrs") +class CallTIRInplaceAttrs(Attrs): + """Attributes used in call_tir_inplace operator""" + + +@tvm.ffi.register_object("relax.attrs.ToVDeviceAttrs") +class ToVDeviceAttrs(Attrs): + """Attributes used in to_vdevice operator""" + + +@tvm.ffi.register_object("relax.attrs.HintOnDeviceAttrs") +class HintOnDeviceAttrs(Attrs): + """Attributes used in hint_on_device operator""" + + +@tvm.ffi.register_object("relax.attrs.ScatterCollectiveAttrs") +class ScatterCollectiveAttrs(Attrs): + """Attributes used in scatter collective operators""" + + +@tvm.ffi.register_object("relax.attrs.AttentionAttrs") +class AttentionAttrs(Attrs): + """Attributes used in attention operator""" + + +@tvm.ffi.register_object("relax.attrs.Conv1DAttrs") +class Conv1DAttrs(Attrs): + """Attributes for nn.conv1d""" + + +@tvm.ffi.register_object("relax.attrs.Conv1DTransposeAttrs") +class Conv1DTransposeAttrs(Attrs): + """Attributes for nn.conv1d_transpose""" + + +@tvm.ffi.register_object("relax.attrs.Pool1DAttrs") +class Pool1DAttrs(Attrs): + """Attributes for nn.max_pool1d and nn.avg_pool1d""" + + +@tvm.ffi.register_object("relax.attrs.Pool3DAttrs") +class Pool3DAttrs(Attrs): + """Attributes for nn.max_pool3d and nn.avg_pool3d""" + + +@tvm.ffi.register_object("relax.attrs.AdaptivePool1DAttrs") +class AdaptivePool1DAttrs(Attrs): + """Attributes for 1d adaptive pool operator""" + + +@tvm.ffi.register_object("relax.attrs.AdaptivePool3DAttrs") +class AdaptivePool3DAttrs(Attrs): + """Attributes for 3d adaptive pool operator""" + + +@tvm.ffi.register_object("relax.attrs.LeakyReluAttrs") +class LeakyReluAttrs(Attrs): + """Attributes used in leaky_relu operator""" + + +@tvm.ffi.register_object("relax.attrs.SoftplusAttrs") +class SoftplusAttrs(Attrs): + """Attributes used in softplus operator""" + + +@tvm.ffi.register_object("relax.attrs.PReluAttrs") +class PReluAttrs(Attrs): + """Attributes used in prelu operator""" + + +@tvm.ffi.register_object("relax.attrs.PixelShuffleAttrs") +class PixelShuffleAttrs(Attrs): + """Attributes used in pixel_shuffle operator""" + + +@tvm.ffi.register_object("relax.attrs.GroupNormAttrs") +class GroupNormAttrs(Attrs): + """Attributes used in group_norm operator""" + + +@tvm.ffi.register_object("relax.attrs.RMSNormAttrs") +class RMSNormAttrs(Attrs): + """Attributes used in rms_norm operator""" + + +@tvm.ffi.register_object("relax.attrs.NLLLossAttrs") +class NLLLossAttrs(Attrs): + """Attributes used in nll_loss operator""" + + +@tvm.ffi.register_object("relax.attrs.AllReduceAttrs") +class AllReduceAttrs(Attrs): + """Attributes used in allreduce operator""" + + +@tvm.ffi.register_object("relax.attrs.AllGatherAttrs") +class AllGatherAttrs(Attrs): + """Attributes used in allgather operator""" + + +@tvm.ffi.register_object("relax.attrs.WrapParamAttrs") +class WrapParamAttrs(Attrs): + """Attributes used in wrap_param operator""" + + +@tvm.ffi.register_object("relax.attrs.QuantizeAttrs") +class QuantizeAttrs(Attrs): + """Attributes used in quantize/dequantize operators""" + + +@tvm.ffi.register_object("relax.attrs.GatherElementsAttrs") +class GatherElementsAttrs(Attrs): + """Attributes for gather_elements operator""" + + +@tvm.ffi.register_object("relax.attrs.GatherNDAttrs") +class GatherNDAttrs(Attrs): + """Attributes for gather_nd operator""" + + +@tvm.ffi.register_object("relax.attrs.MeshgridAttrs") +class MeshgridAttrs(Attrs): + """Attributes for meshgrid operator""" + + +@tvm.ffi.register_object("relax.attrs.ScatterElementsAttrs") +class ScatterElementsAttrs(Attrs): + """Attributes for scatter_elements operator""" + + +@tvm.ffi.register_object("relax.attrs.ScatterNDAttrs") +class ScatterNDAttrs(Attrs): + """Attributes for scatter_nd operator""" + + +@tvm.ffi.register_object("relax.attrs.SliceScatterAttrs") +class SliceScatterAttrs(Attrs): + """Attributes for slice_scatter operator""" + + +@tvm.ffi.register_object("relax.attrs.OneHotAttrs") +class OneHotAttrs(Attrs): + """Attributes for one_hot operator""" diff --git a/python/tvm/runtime/_ffi_node_api.py b/python/tvm/runtime/_ffi_node_api.py index aef9ded9cc0d..4a0edd449c24 100644 --- a/python/tvm/runtime/_ffi_node_api.py +++ b/python/tvm/runtime/_ffi_node_api.py @@ -28,14 +28,6 @@ def AsRepr(obj): return type(obj).__name__ + "(" + obj.__ctypes_handle__().value + ")" -def NodeListAttrNames(obj): - return lambda x: 0 - - -def NodeGetAttr(obj, name): - raise AttributeError() - - def SaveJSON(obj): raise RuntimeError("Do not support object serialization in runtime only mode") diff --git a/python/tvm/runtime/object.py b/python/tvm/runtime/object.py index 688682d197c5..b2fcddc40ad6 100644 --- a/python/tvm/runtime/object.py +++ b/python/tvm/runtime/object.py @@ -22,17 +22,6 @@ from . import _ffi_node_api -def __object_dir__(obj): - class_names = dir(obj.__class__) - fnames = _ffi_node_api.NodeListAttrNames(obj) - size = fnames(-1) - return sorted([fnames(i) for i in range(size)] + class_names) - - tvm.ffi.core._set_class_object(Object) # override the default repr function for tvm.ffi.core.Object tvm.ffi.core.__object_repr__ = _ffi_node_api.AsRepr -tvm.ffi.core.__object_save_json__ = _ffi_node_api.SaveJSON -tvm.ffi.core.__object_load_json__ = _ffi_node_api.LoadJSON -tvm.ffi.core.__object_getattr__ = _ffi_node_api.NodeGetAttr -tvm.ffi.core.__object_dir__ = __object_dir__ diff --git a/python/tvm/script/printer/doc.py b/python/tvm/script/printer/doc.py index 02a67e916bc0..bf468b17ec18 100644 --- a/python/tvm/script/printer/doc.py +++ b/python/tvm/script/printer/doc.py @@ -26,25 +26,12 @@ from . import _ffi_api +@register_object("script.printer.Doc") class Doc(Object): """Base class of all Docs""" - @property - def source_paths(self) -> Sequence[ObjectPath]: - """ - The list of object paths of the source IR node. - - This is used to trace back to the IR node position where - this Doc is generated, in order to position the diagnostic - message. - """ - return self.__getattr__("source_paths") # pylint: disable=unnecessary-dunder-call - - @source_paths.setter - def source_paths(self, value): - return _ffi_api.DocSetSourcePaths(self, value) # type: ignore # pylint: disable=no-member - +@register_object("script.printer.ExprDoc") class ExprDoc(Doc): """Base class of all expression Docs""" @@ -114,26 +101,10 @@ def __iter__(self): raise RuntimeError(f"{self.__class__} cannot be used as iterable.") +@register_object("script.printer.StmtDoc") class StmtDoc(Doc): """Base class of statement doc""" - @property - def comment(self) -> Optional[str]: - """ - The comment of this doc. - - The actual position of the comment depends on the type of Doc - and also the DocPrinter implementation. It could be on the same - line as the statement, or the line above, or inside the statement - if it spans over multiple lines. - """ - # It has to call the dunder method to avoid infinite recursion - return self.__getattr__("comment") # pylint: disable=unnecessary-dunder-call - - @comment.setter - def comment(self, value): - return _ffi_api.StmtDocSetComment(self, value) # type: ignore # pylint: disable=no-member - @register_object("script.printer.StmtBlockDoc") class StmtBlockDoc(Doc): diff --git a/python/tvm/te/tensor.py b/python/tvm/te/tensor.py index 489ec38ba506..73b995a45e61 100644 --- a/python/tvm/te/tensor.py +++ b/python/tvm/te/tensor.py @@ -84,26 +84,6 @@ def ndim(self): """Dimension of the tensor.""" return len(self.shape) - @property - def axis(self): - """Axis of the tensor.""" - return self.__getattr__("axis") - - @property - def op(self): - """The corressponding :py:class:`Operation`.""" - return self.__getattr__("op") - - @property - def value_index(self): - """The output value index the tensor corresponds to.""" - return self.__getattr__("value_index") - - @property - def shape(self): - """The output shape of the tensor.""" - return self.__getattr__("shape") - @property def name(self): op = self.op diff --git a/python/tvm/testing/__init__.py b/python/tvm/testing/__init__.py index ea798242b462..620a66351d9c 100644 --- a/python/tvm/testing/__init__.py +++ b/python/tvm/testing/__init__.py @@ -43,3 +43,4 @@ ) from .runner import local_run, rpc_run from .utils import * +from .attrs import * diff --git a/python/tvm/testing/attrs.py b/python/tvm/testing/attrs.py new file mode 100644 index 000000000000..ea6f1b1af65c --- /dev/null +++ b/python/tvm/testing/attrs.py @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, import-outside-toplevel, unused-variable +"""Testing utilities for attrs""" +from ..ir import Attrs +from ..ffi import register_object + + +@register_object("attrs.TestAttrs") +class TestAttrs(Attrs): + """Attrs used for testing purposes""" + + +__all__ = ["TestAttrs"] diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index 81ce63b7972f..93a182ca3bc2 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -23,6 +23,8 @@ from . import _ffi_api from . import function_pass as _fpass +from ... import ir as _ir +from ... import ffi as _ffi def Apply(ftransform): @@ -48,6 +50,11 @@ def _transform(func, mod, ctx): return _fpass.prim_func_pass(_transform, opt_level=0, name="Apply") # type: ignore +@_ffi.register_object("tir.transform.LoopPartitionConfig") +class LoopPartitionConfig(_ir.Attrs): + """Config for loop partition pass""" + + def LoopPartition(): """Inject virtual thread loops. @@ -87,6 +94,11 @@ def InjectVirtualThread(): return _ffi_api.InjectVirtualThread() # type: ignore +@_ffi.register_object("tir.transform.InjectDoubleBufferConfig") +class InjectDoubleBufferConfig(_ir.Attrs): + """Config for inject double buffer pass""" + + def InjectDoubleBuffer(): """Inject double buffer statements. @@ -149,6 +161,11 @@ def PointerValueTypeRewrite(): return _ffi_api.PointerValueTypeRewrite() # type: ignore +@_ffi.register_object("tir.transform.UnrollLoopConfig") +class UnrollLoopConfig(_ir.Attrs): + """Config for unroll loop pass""" + + def UnrollLoop(): """Unroll the constant loop marked by unroll. @@ -162,6 +179,11 @@ def UnrollLoop(): return _ffi_api.UnrollLoop() # type: ignore +@_ffi.register_object("tir.transform.ReduceBranchingThroughOvercomputeConfig") +class ReduceBranchingThroughOvercomputeConfig(_ir.Attrs): + """Config for reduce branching through overcompute pass""" + + def ReduceBranchingThroughOvercompute(): """Reduce branching by introducing overcompute @@ -173,6 +195,11 @@ def ReduceBranchingThroughOvercompute(): return _ffi_api.ReduceBranchingThroughOvercompute() # type: ignore +@_ffi.register_object("tir.transform.RemoveNoOpConfig") +class RemoveNoOpConfig(_ir.Attrs): + """Config for remove no op pass""" + + def RemoveNoOp(): """Remove No Op from the Stmt. @@ -277,6 +304,11 @@ def RewriteUnsafeSelect(): return _ffi_api.RewriteUnsafeSelect() # type: ignore +@_ffi.register_object("tir.transform.SimplifyConfig") +class SimplifyConfig(_ir.Attrs): + """Config for simplify pass""" + + def Simplify(): """Run arithmetic simplifications on the statements and expressions. @@ -607,6 +639,11 @@ def VerifyVTCMLimit(limit=None): return _ffi_api.VerifyVTCMLimit(limit) # type: ignore +@_ffi.register_object("tir.transform.HoistIfThenElseConfig") +class HoistIfThenElseConfig(_ir.Attrs): + """Config for hoist if then else pass""" + + # pylint: disable=no-else-return,inconsistent-return-statements def HoistIfThenElse(variant: Optional[str] = None): """Hoist loop-invariant IfThenElse nodes to outside the eligible loops. @@ -686,6 +723,11 @@ class HoistedLetBindings(enum.Flag): """ Enable all hoisting of let bindings """ +@_ffi.register_object("tir.transform.HoistExpressionConfig") +class HoistExpressionConfig(_ir.Attrs): + """Config for hoist expression pass""" + + def HoistExpression(): """Generalized verison of HoistIfThenElse. diff --git a/src/node/reflection.cc b/src/node/reflection.cc index 6db751a80f87..e666b434f8f5 100644 --- a/src/node/reflection.cc +++ b/src/node/reflection.cc @@ -33,75 +33,6 @@ using ffi::Any; using ffi::Function; using ffi::PackedArgs; -// Expose to FFI APIs. -void NodeGetAttr(ffi::PackedArgs args, ffi::Any* ret) { - Object* self = const_cast(args[0].cast()); - String field_name = args[1].cast(); - - bool success; - if (field_name == "type_key") { - *ret = self->GetTypeKey(); - success = true; - } else if (!self->IsInstance()) { - const TVMFFITypeInfo* type_info = TVMFFIGetTypeInfo(self->type_index()); - success = false; - // use new reflection mechanism - if (type_info->metadata != nullptr) { - ffi::reflection::ForEachFieldInfo(type_info, [&](const TVMFFIFieldInfo* field_info) { - if (field_name.compare(field_info->name) == 0) { - ffi::reflection::FieldGetter field_getter(field_info); - *ret = field_getter(self); - success = true; - } - }); - } - } else { - // specially handle dict attr - DictAttrsNode* dnode = static_cast(self); - auto it = dnode->dict.find(field_name); - if (it != dnode->dict.end()) { - success = true; - *ret = (*it).second; - } else { - success = false; - } - } - if (!success) { - TVM_FFI_THROW(AttributeError) << self->GetTypeKey() << " object has no attribute `" - << field_name << "`"; - } -} - -void NodeListAttrNames(ffi::PackedArgs args, ffi::Any* ret) { - Object* self = const_cast(args[0].cast()); - - std::vector names; - if (!self->IsInstance()) { - const TVMFFITypeInfo* type_info = TVMFFIGetTypeInfo(self->type_index()); - if (type_info->metadata != nullptr) { - // use new reflection mechanism - ffi::reflection::ForEachFieldInfo(type_info, [&](const TVMFFIFieldInfo* field_info) { - names.push_back(std::string(field_info->name.data, field_info->name.size)); - }); - } - } else { - // specially handle dict attr - DictAttrsNode* dnode = static_cast(self); - for (const auto& kv : dnode->dict) { - names.push_back(kv.first); - } - } - - *ret = ffi::Function::FromPacked([names](ffi::PackedArgs args, ffi::Any* rv) { - int64_t i = args[0].cast(); - if (i == -1) { - *rv = static_cast(names.size()); - } else { - *rv = names[i]; - } - }); -} - // API function to make node. // args format: // key1, value1, ..., key_n, value_n @@ -123,10 +54,7 @@ void MakeNode(const ffi::PackedArgs& args, ffi::Any* rv) { TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; - refl::GlobalDef() - .def_packed("node.NodeGetAttr", NodeGetAttr) - .def_packed("node.NodeListAttrNames", NodeListAttrNames) - .def_packed("node.MakeNode", MakeNode); + refl::GlobalDef().def_packed("node.MakeNode", MakeNode); }); } // namespace tvm diff --git a/src/relax/ir/emit_te.h b/src/relax/ir/emit_te.h index bc4b90a37333..aa7cb9db538e 100644 --- a/src/relax/ir/emit_te.h +++ b/src/relax/ir/emit_te.h @@ -52,7 +52,7 @@ class RXPlaceholderOpNode : public te::PlaceholderOpNode { .def_ro("dtype", &RXPlaceholderOpNode::dtype); } - static constexpr const char* _type_key = "RXPlaceholderOp"; + static constexpr const char* _type_key = "relax.TEPlaceholderOp"; TVM_DECLARE_FINAL_OBJECT_INFO(RXPlaceholderOpNode, te::PlaceholderOpNode); }; diff --git a/src/tir/transforms/hoist_expression.cc b/src/tir/transforms/hoist_expression.cc index d89114c68abd..1548ea1da625 100644 --- a/src/tir/transforms/hoist_expression.cc +++ b/src/tir/transforms/hoist_expression.cc @@ -82,7 +82,7 @@ struct HoistExpressionConfigNode : public AttrsNodeReflAdapter(flag) & hoisted_let_bindings; } - static constexpr const char* _type_key = "tir.transforms.HoistExpressionConfig"; + static constexpr const char* _type_key = "tir.transform.HoistExpressionConfig"; TVM_DECLARE_FINAL_OBJECT_INFO(HoistExpressionConfigNode, Object); }; @@ -112,7 +112,7 @@ struct HoistIfThenElseConfigNode : public AttrsNodeReflAdapter R.Tensor((5, "b * 2"), dtype="float32"): b = T.int64() lv: R.Shape([5, b * 2]) = R.shape([5, b * 2]) diff --git a/tests/python/runtime/test_runtime_rpc.py b/tests/python/runtime/test_runtime_rpc.py index 6711ccf92f3f..e696cbcf086c 100644 --- a/tests/python/runtime/test_runtime_rpc.py +++ b/tests/python/runtime/test_runtime_rpc.py @@ -413,7 +413,6 @@ def check(client, is_local): get_elem = client.get_function("testing.GetShapeElem") get_size = client.get_function("testing.GetShapeSize") shape = make_shape(2, 3) - assert shape.type_key == "runtime.RPCObjectRef" assert get_elem(shape, 0) == 2 assert get_elem(shape, 1) == 3 assert get_size(shape) == 2