diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index 974747c77416..381be8514916 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -400,14 +400,6 @@ TVM_DLL PrimExpr operator~(PrimExpr a); */ class RelaxExprNode : public BaseExprNode { public: - /*! - * \brief Stores the result of type inference(type checking). - * - * \note This can be undefined before type inference. - * This value is discarded during serialization. - */ - mutable Type checked_type_ = Type(nullptr); - /*! * \brief Stores the result of structure information of the * expression that encapsulate both static shape and @@ -415,23 +407,6 @@ class RelaxExprNode : public BaseExprNode { */ mutable Optional struct_info_ = Optional(); - /*! - * \return The checked_type - */ - inline const Type& checked_type() const; - /*! - * \brief Check if the inferred(checked) type of the Expr - * is backed by a TTypeNode and return it. - * - * \note This function will thrown an error if the node type - * of this Expr is not TTypeNode. - * - * \return The corresponding TTypeNode pointer. - * \tparam The specific TypeNode we look for. - */ - template - inline const TTypeNode* type_as() const; - static constexpr const char* _type_key = "RelaxExpr"; static constexpr const uint32_t _type_child_slots = 22; TVM_DECLARE_BASE_OBJECT_INFO(RelaxExprNode, BaseExprNode); @@ -463,7 +438,6 @@ class GlobalVarNode : public RelaxExprNode { void VisitAttrs(AttrVisitor* v) { v->Visit("name_hint", &name_hint); v->Visit("span", &span); - v->Visit("_checked_type_", &checked_type_); v->Visit("struct_info_", &struct_info_); } @@ -487,7 +461,7 @@ class GlobalVarNode : public RelaxExprNode { */ class GlobalVar : public RelaxExpr { public: - TVM_DLL explicit GlobalVar(String name_hint, Type type = {}, Span span = {}); + TVM_DLL explicit GlobalVar(String name_hint, Span span = {}); TVM_DEFINE_OBJECT_REF_METHODS(GlobalVar, RelaxExpr, GlobalVarNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(GlobalVarNode); @@ -747,26 +721,6 @@ class Range : public ObjectRef { TVM_DEFINE_OBJECT_REF_METHODS(Range, ObjectRef, RangeNode); }; -// implementations -inline const Type& RelaxExprNode::checked_type() const { - ICHECK(checked_type_.defined()) << "internal error: the type checker has " - << "not populated the checked_type " - << "field for " << GetRef(this); - return this->checked_type_; -} - -template -inline const TTypeNode* RelaxExprNode::type_as() const { - static_assert(std::is_base_of::value, - "TType must be a special case of type"); - ICHECK(checked_type_.defined()) - << "Type inference for this Expr has not completed. Try to call infer_type pass."; - const TTypeNode* node = checked_type_.as(); - ICHECK(node != nullptr) << "Expected type to be " << TTypeNode::_type_key << ", but get " - << checked_type_->GetTypeKey(); - return node; -} - namespace ffi { // Type traits to enable automatic conversion into IntImm, Integer, and Bool // when called through the FFI diff --git a/include/tvm/ir/function.h b/include/tvm/ir/function.h index fa51856a0104..61c170b36639 100644 --- a/include/tvm/ir/function.h +++ b/include/tvm/ir/function.h @@ -131,7 +131,7 @@ constexpr const char* kGlobalSymbol = "global_symbol"; * \brief Base node of all functions. * * We support several variants of functions throughout the stack. - * All of the functions share the same type system(via checked_type) + * All of the functions share the same type system * to support cross variant calls. 
* * \sa BaseFunc diff --git a/include/tvm/relax/dataflow_pattern.h b/include/tvm/relax/dataflow_pattern.h index 987fe16207dc..36d07516086f 100644 --- a/include/tvm/relax/dataflow_pattern.h +++ b/include/tvm/relax/dataflow_pattern.h @@ -55,7 +55,6 @@ class AndPattern; class NotPattern; class ShapePattern; class StructInfoPattern; -class TypePattern; class DataTypePattern; class AttrPattern; class SameShapeConstraint; @@ -116,8 +115,6 @@ class DFPattern : public ObjectRef { TVM_DLL AttrPattern HasAttr(const Map& attrs) const; /*! \brief Syntatic Sugar for creating a StructInfoPattern */ TVM_DLL StructInfoPattern HasStructInfo(const StructInfo& struct_info) const; - /*! \brief Syntatic Sugar for creating a TypePattern */ - TVM_DLL TypePattern HasType(const Type& type) const; /*! \brief Syntatic Sugar for creating a DataTypePattern with a DataType */ TVM_DLL DataTypePattern HasDtype(const DataType& dtype) const; /*! \brief Syntatic Sugar for creating a DataTypePattern with a data type's name */ @@ -742,34 +739,6 @@ class WildcardPattern : public DFPattern { TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(WildcardPattern, DFPattern, WildcardPatternNode); }; -/*! - * \brief Pattern for matching a certain type. - * \sa TypePattern - */ -class TypePatternNode : public DFPatternNode { - public: - DFPattern pattern; /*!< The pattern to match */ - Type type; /*!< The type to match */ - - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("pattern", &pattern); - v->Visit("type", &type); - } - - static constexpr const char* _type_key = "relax.dpl.TypePattern"; - TVM_DECLARE_FINAL_OBJECT_INFO(TypePatternNode, DFPatternNode); -}; - -/*! - * \brief Managed reference to TypePatternNode. - * \sa TypePatternNode - */ -class TypePattern : public DFPattern { - public: - TVM_DLL TypePattern(DFPattern pattern, Type type); - TVM_DEFINE_OBJECT_REF_METHODS(TypePattern, DFPattern, TypePatternNode); -}; - /*! * \brief Pattern for matching a certain struct info. * \sa StructInfoPattern diff --git a/include/tvm/relax/dataflow_pattern_functor.h b/include/tvm/relax/dataflow_pattern_functor.h index fb67f3cc4aca..c12ab0326df4 100644 --- a/include/tvm/relax/dataflow_pattern_functor.h +++ b/include/tvm/relax/dataflow_pattern_functor.h @@ -96,7 +96,6 @@ class DFPatternFunctor { virtual R VisitDFPattern_(const TuplePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; virtual R VisitDFPattern_(const StructInfoPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; - virtual R VisitDFPattern_(const TypePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; virtual R VisitDFPattern_(const WildcardPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; virtual R VisitDFPattern_(const VarPatternNode* op, Args... 
args) DFPATTERN_FUNCTOR_DEFAULT; @@ -132,7 +131,6 @@ class DFPatternFunctor { RELAX_DFPATTERN_FUNCTOR_DISPATCH(TupleGetItemPatternNode); RELAX_DFPATTERN_FUNCTOR_DISPATCH(TuplePatternNode); RELAX_DFPATTERN_FUNCTOR_DISPATCH(StructInfoPatternNode); - RELAX_DFPATTERN_FUNCTOR_DISPATCH(TypePatternNode); RELAX_DFPATTERN_FUNCTOR_DISPATCH(WildcardPatternNode); RELAX_DFPATTERN_FUNCTOR_DISPATCH(VarPatternNode); RELAX_DFPATTERN_FUNCTOR_DISPATCH(DataflowVarPatternNode); @@ -167,7 +165,6 @@ class DFPatternVisitor : public DFPatternFunctor { void VisitDFPattern_(const TupleGetItemPatternNode* op) override; void VisitDFPattern_(const TuplePatternNode* op) override; void VisitDFPattern_(const StructInfoPatternNode* op) override; - void VisitDFPattern_(const TypePatternNode* op) override; void VisitDFPattern_(const WildcardPatternNode* op) override; void VisitDFPattern_(const VarPatternNode* op) override; diff --git a/include/tvm/relax/expr.h b/include/tvm/relax/expr.h index 08cbd1de538e..6197d1ed280f 100644 --- a/include/tvm/relax/expr.h +++ b/include/tvm/relax/expr.h @@ -166,7 +166,6 @@ class CallNode : public ExprNode { v->Visit("attrs", &attrs); v->Visit("sinfo_args", &sinfo_args); v->Visit("struct_info_", &struct_info_); - v->Visit("_checked_type_", &checked_type_); v->Visit("span", &span); } @@ -224,7 +223,6 @@ class TupleNode : public ExprNode { void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("fields", &fields); - v->Visit("_checked_type_", &checked_type_); v->Visit("struct_info_", &struct_info_); v->Visit("span", &span); } @@ -291,7 +289,6 @@ class TupleGetItemNode : public ExprNode { v->Visit("tuple_value", &tuple); v->Visit("index", &index); v->Visit("struct_info_", &struct_info_); - v->Visit("_checked_type_", &checked_type_); v->Visit("span", &span); } @@ -362,7 +359,6 @@ class ShapeExprNode : public LeafExprNode { void VisitAttrs(AttrVisitor* v) { v->Visit("values", &values); v->Visit("struct_info_", &struct_info_); - v->Visit("_checked_type_", &checked_type_); v->Visit("span", &span); } @@ -399,7 +395,6 @@ class VarNode : public LeafExprNode { void VisitAttrs(AttrVisitor* v) { v->Visit("vid", &vid); v->Visit("struct_info_", &struct_info_); - v->Visit("_checked_type_", &checked_type_); v->Visit("span", &span); } @@ -440,7 +435,6 @@ class DataflowVarNode : public VarNode { void VisitAttrs(AttrVisitor* v) { v->Visit("vid", &vid); v->Visit("struct_info_", &struct_info_); - v->Visit("_checked_type_", &checked_type_); v->Visit("span", &span); } @@ -492,7 +486,6 @@ class ConstantNode : public LeafExprNode { void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("data", &data); v->Visit("struct_info_", &struct_info_); - v->Visit("_checked_type_", &checked_type_); v->Visit("span", &span); } @@ -540,7 +533,6 @@ class PrimValueNode : public LeafExprNode { void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("value", &value); v->Visit("struct_info_", &struct_info_); - v->Visit("_checked_type_", &checked_type_); v->Visit("span", &span); } @@ -591,7 +583,6 @@ class StringImmNode : public LeafExprNode { void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("value", &value); v->Visit("struct_info_", &struct_info_); - v->Visit("_checked_type_", &checked_type_); v->Visit("span", &span); } @@ -634,7 +625,6 @@ class DataTypeImmNode : public LeafExprNode { void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("value", &value); v->Visit("struct_info_", &struct_info_); - v->Visit("_checked_type_", &checked_type_); v->Visit("span", &span); } @@ -824,7 +814,6 @@ class SeqExprNode : public ExprNode { v->Visit("blocks", &blocks); 
v->Visit("body", &body); v->Visit("struct_info_", &struct_info_); - v->Visit("_checked_type_", &checked_type_); v->Visit("span", &span); } @@ -889,7 +878,6 @@ class IfNode : public ExprNode { v->Visit("cond", &cond); v->Visit("true_branch", &true_branch); v->Visit("false_branch", &false_branch); - v->Visit("_checked_type_", &checked_type_); v->Visit("struct_info_", &struct_info_); v->Visit("span", &span); } @@ -966,7 +954,6 @@ class FunctionNode : public BaseFuncNode { v->Visit("ret_struct_info", &ret_struct_info); v->Visit("attrs", &attrs); v->Visit("struct_info_", &struct_info_); - v->Visit("_checked_type_", &checked_type_); v->Visit("span", &span); } @@ -1071,7 +1058,6 @@ class ExternFuncNode : public BaseFuncNode { void VisitAttrs(AttrVisitor* v) { v->Visit("global_symbol", &global_symbol); v->Visit("struct_info_", &struct_info_); - v->Visit("_checked_type_", &checked_type_); v->Visit("span", &span); } diff --git a/include/tvm/relax/expr_functor.h b/include/tvm/relax/expr_functor.h index c77383bdbf3d..7634bc34a26f 100644 --- a/include/tvm/relax/expr_functor.h +++ b/include/tvm/relax/expr_functor.h @@ -312,7 +312,7 @@ void PostOrderVisit(const Expr& node, std::function fvisit); /*! * \brief A mutator works in unnormalized form. * - * ExprMutatorBase expects input AST to be in the unnormalized form, i.e., checked_type_ and shape_ + * ExprMutatorBase expects input AST to be in the unnormalized form, i.e., struct_info_ * of expressions can be nullptr, and the expressions may nest(and as a result the AST is not in * ANF). */ @@ -414,7 +414,7 @@ class ExprMutatorBase : public ExprFunctor { * \brief A mutator works in normal form. * * ExprMutator expects input AST to be in the normal form, i.e., the expressions are normalized(no - * nesting and hence the AST is in ANF), and all checked_type_ and shape_ of expressions are + * nesting and hence the AST is in ANF), and all struct_info_ of expressions are * available. */ class ExprMutator : public ExprMutatorBase { diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h index b8ff0fa59dfe..6ccd693bff02 100644 --- a/include/tvm/relax/transform.h +++ b/include/tvm/relax/transform.h @@ -149,7 +149,7 @@ TVM_DLL Pass AttachGlobalSymbol(); /*! * \brief Transform Relax IR to normal form: transform AST to A-normal form, and fill the - * checked_type_ and shape_ of expressions. + * struct_info_ of expressions. * * \return The Pass. */ diff --git a/include/tvm/tir/function.h b/include/tvm/tir/function.h index edb88bcafe55..92fe19e8aa35 100644 --- a/include/tvm/tir/function.h +++ b/include/tvm/tir/function.h @@ -107,7 +107,6 @@ class PrimFuncNode : public BaseFuncNode { v->Visit("buffer_map", &buffer_map); v->Visit("attrs", &attrs); v->Visit("span", &span); - v->Visit("_checked_type_", &checked_type_); } bool SEqualReduce(const PrimFuncNode* other, SEqualReducer equal) const { diff --git a/python/tvm/ir/expr.py b/python/tvm/ir/expr.py index 197b0831bf25..1d5389827f8e 100644 --- a/python/tvm/ir/expr.py +++ b/python/tvm/ir/expr.py @@ -23,7 +23,6 @@ from ..runtime import Object, Scriptable from . import _ffi_api from .base import Node, Span -from .type import Type class BaseExpr(Node): @@ -45,20 +44,6 @@ class PrimExpr(BaseExpr): class RelaxExpr(BaseExpr): """Base class of all non-primitive expressions.""" - @property - def checked_type(self): - """Get the checked type of tvm.relax.Expr. - - Returns - ------- - checked_type : tvm.ir.Type - The checked type. 
- """ - ret = self._checked_type_ - if ret is None: - raise ValueError("The type checker has not populated the checked_type for this node") - return ret - @property def struct_info(self) -> Optional["tvm.relax.StructInfo"]: """Get the struct info field @@ -86,8 +71,8 @@ class GlobalVar(RelaxExpr): name_hint: str - def __init__(self, name_hint: str, type_annot: Optional[Type] = None): - self.__init_handle_by_constructor__(_ffi_api.GlobalVar, name_hint, type_annot) + def __init__(self, name_hint: str): + self.__init_handle_by_constructor__(_ffi_api.GlobalVar, name_hint) def __call__(self, *args: RelaxExpr) -> BaseExpr: """Call the global variable. diff --git a/python/tvm/relax/dpl/pattern.py b/python/tvm/relax/dpl/pattern.py index 42486ee28948..34aab9c99e77 100644 --- a/python/tvm/relax/dpl/pattern.py +++ b/python/tvm/relax/dpl/pattern.py @@ -125,22 +125,6 @@ def has_attr(self, attrs: Dict[str, Object]) -> "AttrPattern": def has_struct_info(self, struct_info: "StructInfo") -> "StructInfoPattern": return StructInfoPattern(self, struct_info) - def has_type(self, ttype: tvm.ir.type.Type) -> "TypePattern": - """ - Add a type constraint to this pattern - - Parameters - ---------- - ttype: tvm.ir.type.Type - The type to match - - Returns - ------- - result: TypePattern - The resulting TypePattern - """ - return TypePattern(self, ttype) - def has_dtype(self, dtype: str) -> "DataTypePattern": """ Add a type constraint to this pattern @@ -598,23 +582,6 @@ def __init__(self, pattern: "DFPattern", struct_info: "StructInfo"): ) # type: ignore -@register_df_node -class TypePattern(DFPattern): - """A pattern that matches another pattern with a certain type annotation. - - Parameters - ---------- - pattern: tvm.relax.dpl.DFPattern - The input pattern that needs type annotation. - - ttype: tvm.ir.type.Type - The type to match. 
- """ - - def __init__(self, pattern: "DFPattern", ttype: tvm.ir.type.Type): - self.__init_handle_by_constructor__(ffi.TypePattern, pattern, ttype) # type: ignore - - @register_df_node class DataTypePattern(DFPattern): """A pattern that matches another pattern with certain data type diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index bae308f435ce..201e99e3d10f 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -913,7 +913,7 @@ def _impl_v13(cls, bb, inputs, attr, params): A = inputs[0] B = inputs[1] C = inputs[2] - dtype = A.checked_type.dtype + dtype = A.struct_info.dtype # Compute Y = alpha * A X B + beta * C @@ -1083,7 +1083,7 @@ class Mish(OnnxOpConverter): @classmethod def _impl_v18(cls, bb, inputs, attr, params): - dtype = inputs[0].checked_type.dtype + dtype = inputs[0].struct_info.dtype return inputs[0] * relax.op.tanh( relax.op.log(relax.const(1.0, dtype) + relax.op.exp(inputs[0])) ) @@ -1670,7 +1670,7 @@ def _check_type(cls, dtype, valid_types): def _impl_v1(cls, bb, inputs, attr, params): data = inputs[0] valid_types = ["float", "float32", "double", "float64", "float16"] - cls._check_type(data.checked_type.dtype, valid_types) + cls._check_type(data.struct_info.dtype, valid_types) return relax.op.exp(data) @@ -1678,7 +1678,7 @@ def _impl_v1(cls, bb, inputs, attr, params): def _impl_v13(cls, bb, inputs, attr, params): data = inputs[0] valid_types = ["float", "float32", "double", "float64", "float16", "bfloat16"] - cls._check_type(data.checked_type.dtype, valid_types) + cls._check_type(data.struct_info.dtype, valid_types) return relax.op.exp(data) @@ -1723,7 +1723,7 @@ def _impl_v13(cls, bb, inputs, attr, params): splits = inputs[1] splits_rank = None if splits is not None: - splits_rank = splits.checked_type.ndim + splits_rank = splits.struct_info.ndim if splits is not None and splits_rank > 0: if isinstance(splits, relax.Constant): splits = splits.data.numpy() @@ -3508,11 +3508,11 @@ def _construct_nodes(self, graph: onnx.onnx_ml_pb2.GraphProto): if op_name in return_tuple_ops: outputs_num = 1 elif not isinstance(op, relax.Tuple): - if isinstance(op.checked_type, tvm.ir.type.TupleType): + if isinstance(op.struct_info, relax.TupleStructInfo): # This is a var bound to a tuple. We need to unpack it and create # a new tuple. 
tuple_items = [] - for i in range(len(op.checked_type.fields)): + for i in range(len(op.struct_info.fields)): tuple_items.append(self.bb.emit(relax.TupleGetItem(op, i))) op = relax.Tuple(tuple_items) outputs_num = len(tuple_items) diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 33abccbe5f85..8edf131aaa96 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -409,8 +409,8 @@ def _group_norm_module(self, node: fx.Node) -> relax.Var: gamma = self.params[module.weight] beta = self.params[module.bias] else: - gamma = relax.const(torch.ones_like(module.num_channels), x.checked_type) - beta = relax.const(torch.zeros_like(module.num_channels), x.checked_type) + gamma = relax.const(torch.ones_like(module.num_channels), x.struct_info.dtype) + beta = relax.const(torch.zeros_like(module.num_channels), x.struct_info.dtype) eps = module.eps dim = len(self.shape_of(x)) diff --git a/python/tvm/relax/op/manipulate.py b/python/tvm/relax/op/manipulate.py index c71b19494a41..864eb3fec709 100644 --- a/python/tvm/relax/op/manipulate.py +++ b/python/tvm/relax/op/manipulate.py @@ -143,7 +143,7 @@ def layout_transform( if callable(index_map): index_map = IndexMap.from_func(index_map, index_dtype=default_index_dtype) - x_dtype = x.checked_type.dtype + x_dtype = x.struct_info.dtype # Explicitly convert python int/float pad_value to the x's type. If the default behavior # is applied, it would be converted to int32/float32, which may not match the x's type. diff --git a/python/tvm/relax/testing/ast_printer.py b/python/tvm/relax/testing/ast_printer.py index 40e617b38f1a..ca50d39f7bd5 100644 --- a/python/tvm/relax/testing/ast_printer.py +++ b/python/tvm/relax/testing/ast_printer.py @@ -45,11 +45,9 @@ def __init__( self, indent_str=" ", include_struct_info_annotations=True, - include_type_annotations=False, include_call_attrs=True, ): self.indent_str = indent_str - self.include_type_annotations = include_type_annotations self.include_struct_info_annotations = include_struct_info_annotations self.include_call_attrs = include_call_attrs @@ -88,13 +86,11 @@ def build_ast_node(self, nodename: str, force_newline=False, **kwargs: str) -> s def build_expr(self, node: relax.Expr, nodename: str, force_newline=False, **kwargs: str): """ Renders a Relax expression as a string using `build_ast_node`. - Handles whether to include the checked_type_ and struct_info fields. + Handles whether to include the struct_info fields. 
""" fields = kwargs.copy() if node.struct_info_ and self.include_struct_info_annotations: fields["struct_info"] = self.visit_struct_info_(node.struct_info) - if node._checked_type_ and self.include_type_annotations: - fields["checked_type_"] = self.visit_type_(node.checked_type) return self.build_ast_node(nodename, force_newline=force_newline, **fields) def build_list( @@ -136,7 +132,7 @@ def visit_shape_expr_(self, op: relax.ShapeExpr) -> str: def visit_extern_func_(self, op: relax.ExternFunc) -> str: # ExternFunc does not inherit from relax.Expr either, - # so it doesn't have checked_type_ or struct_info fields and we don't use build_expr + # so it doesn't have struct_info fields and we don't use build_expr return self.build_ast_node("ExternFunc", global_symbol=wrap_quotes(op.global_symbol)) def visit_global_var_(self, op: relax.GlobalVar) -> str: @@ -220,8 +216,8 @@ def visit_data_type_imm_(self, op: relax.DataTypeImm) -> str: def visit_op_(self, op: tvm.ir.Op) -> str: # TODO: List other attributes? - # op is not actually a Relax expr and does not have checked_type_ - # or struct_info fields, so we don't use build_expr here + # op is not actually a Relax expr and does not have + # struct_info fields, so we don't use build_expr here return self.build_ast_node("Op", name=wrap_quotes(op.name)) def visit_prim_expr_(self, prim_expr: PrimExpr) -> str: @@ -365,7 +361,6 @@ def dump_ast( exp: relax.Expr, indent_str=" ", include_struct_info_annotations=True, - include_type_annotations=False, include_call_attrs=True, ) -> str: """ @@ -376,7 +371,6 @@ def dump_ast( printer = ASTPrinter( indent_str=indent_str, include_struct_info_annotations=include_struct_info_annotations, - include_type_annotations=include_type_annotations, include_call_attrs=include_call_attrs, ) return printer.visit_expr(exp) diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py index 8e74f0897720..57627ceebe66 100644 --- a/python/tvm/relax/transform/transform.py +++ b/python/tvm/relax/transform/transform.py @@ -417,7 +417,7 @@ def CallTIRRewrite() -> tvm.ir.transform.Pass: def Normalize() -> tvm.ir.transform.Pass: """Transforming Relax IR to normal form, i.e., the expressions are normalized(no nesting - and hence the AST is in ANF), and all ``checked_type_`` and ``shape_`` of expressions are + and hence the AST is in ANF), and all `struct_info_` of expressions are available. 
Returns diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py index 97eb3d5b2fe1..155c7e10de60 100644 --- a/python/tvm/tir/op.py +++ b/python/tvm/tir/op.py @@ -432,10 +432,10 @@ def call_tir(global_var: tvm.ir.GlobalVar, *args): assert isinstance(global_var, tvm.ir.GlobalVar) dtype = "void" - if global_var.checked_type is not None: - ret_type = global_var.checked_type.ret_type - if hasattr(ret_type, "dtype"): - dtype = ret_type.dtype + if global_var.struct_info is not None: + ret_sinfo = global_var.struct_info.ret + if hasattr(ret_sinfo, "dtype"): + dtype = ret_sinfo.dtype return Call(dtype=dtype, op=global_var, args=args) diff --git a/src/ir/apply_pass_to_function.cc b/src/ir/apply_pass_to_function.cc index 877530f4c378..1b6a2896cccb 100644 --- a/src/ir/apply_pass_to_function.cc +++ b/src/ir/apply_pass_to_function.cc @@ -88,7 +88,6 @@ Pass ApplyPassToFunction(Pass pass, String func_name_regex, keep_original_version.insert(gvar->name_hint); func = relax::ExternFunc("dummy_" + name); func->struct_info_ = gvar->struct_info_; - func->checked_type_ = gvar->checked_type_; } subset->Add(gvar, func); diff --git a/src/ir/expr.cc b/src/ir/expr.cc index fcfd8deeb11f..b45bcd968421 100644 --- a/src/ir/expr.cc +++ b/src/ir/expr.cc @@ -192,19 +192,16 @@ TVM_FFI_REGISTER_GLOBAL("ir.Range") TVM_REGISTER_NODE_TYPE(RangeNode); -GlobalVar::GlobalVar(String name_hint, Type type, Span span) { +GlobalVar::GlobalVar(String name_hint, Span span) { ObjectPtr n = make_object(); n->name_hint = std::move(name_hint); - n->checked_type_ = std::move(type); n->span = std::move(span); data_ = std::move(n); } TVM_REGISTER_NODE_TYPE(GlobalVarNode); -TVM_FFI_REGISTER_GLOBAL("ir.GlobalVar").set_body_typed([](String name, Type type) { - return GlobalVar(name, type); -}); +TVM_FFI_REGISTER_GLOBAL("ir.GlobalVar").set_body_typed([](String name) { return GlobalVar(name); }); TVM_FFI_REGISTER_GLOBAL("ir.DebugPrint").set_body_typed([](ObjectRef ref) { std::stringstream ss; diff --git a/src/relax/analysis/well_formed.cc b/src/relax/analysis/well_formed.cc index 243033e9454b..315281cc007f 100644 --- a/src/relax/analysis/well_formed.cc +++ b/src/relax/analysis/well_formed.cc @@ -55,7 +55,7 @@ * * The cond field of If nodes * * The op or args fields of Call nodes * * Inside the fields of Tuple nodes - * 13. Expr always has checked_type_ (with the exception of Op). + * 13. Expr always has struct_info_ (with the exception of Op). * 14. DataflowBlocks may not contain If nodes. * 15. DataflowBlocks may not contain calls to impure functions or operators * (only checked if check_struct_info is true). 
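With checked_type_ gone, downstream code queries the same information through the struct_info accessors, as the tir.call_tir update above illustrates. A minimal sketch of the migration for tensor-valued expressions (variable names are hypothetical and the usual import aliases are assumed; this snippet is illustrative, not part of the patch):

    import tvm
    from tvm import relax as rx
    from tvm.script import relax as R

    # A Relax var annotated with shape and dtype information.
    x = rx.Var("x", R.Tensor((2, 3), "float32"))

    # Before this change: x.checked_type.dtype / x.checked_type.ndim (a TensorType).
    # Now the same facts come from the struct info annotation.
    assert isinstance(x.struct_info, rx.TensorStructInfo)
    assert x.struct_info.dtype == "float32"
    assert x.struct_info.ndim == 2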
@@ -147,8 +147,8 @@ class WellFormedChecker : public relax::ExprVisitor,
   }

   void VisitExpr(const Expr& expr) final {
-    if (!expr.as<OpNode>() && !expr->checked_type_.defined()) {
-      Malformed(Diagnostic::Error(expr) << "The checked_type_ of Expr " << expr << " is nullptr.");
+    if (!expr.as<OpNode>() && !expr->struct_info_.defined()) {
+      Malformed(Diagnostic::Error(expr) << "The struct_info_ of Expr " << expr << " is nullptr.");
     }
     relax::ExprVisitor::VisitExpr(expr);
   }
@@ -162,11 +162,10 @@
       }
     }

-    if (op->checked_type_.defined()) {
-      if ((!op->checked_type_->IsInstance<FuncTypeNode>()) &&
-          (!op->checked_type_->IsInstance<PackedFuncTypeNode>())) {
-        Malformed(Diagnostic::Error(var) << "The checked_type_ of GlobalVar " << GetRef<GlobalVar>(op)
-                                         << " must be either FuncType or PackedFuncType.");
+    if (op->struct_info_.defined()) {
+      if (!op->struct_info_->IsInstance<FuncStructInfoNode>()) {
+        Malformed(Diagnostic::Error(var) << "The struct_info_ of GlobalVar " << GetRef<GlobalVar>(op)
+                                         << " must be FuncStructInfo.");
       }
     }
diff --git a/src/relax/backend/contrib/codegen_c/codegen_c.h b/src/relax/backend/contrib/codegen_c/codegen_c.h
index 28ca0e3586e8..795b691dec4c 100644
--- a/src/relax/backend/contrib/codegen_c/codegen_c.h
+++ b/src/relax/backend/contrib/codegen_c/codegen_c.h
@@ -336,9 +336,9 @@ class CodegenCBase {
   * \return The dtype string.
   */
  std::string GetDtypeString(const Var& var) {
-    auto ttype = var->checked_type().as<TensorTypeNode>();
-    ICHECK(ttype) << "Expect TensorTypeNode";
-    return GetDtypeString(ttype);
+    auto tsinfo = var->struct_info_.as<TensorStructInfoNode>();
+    ICHECK(tsinfo) << "Expect TensorStructInfoNode";
+    return GetDtypeString(tsinfo);
  }

  /*!
@@ -348,24 +348,24 @@ class CodegenCBase {
   *
   * \return The dtype string.
   */
-  std::string GetDtypeString(const TensorTypeNode* ttype) {
+  std::string GetDtypeString(const TensorStructInfoNode* tsinfo) {
    std::string dtype;
-    if (runtime::TypeMatch(ttype->dtype, kDLFloat, 32)) {
+    if (runtime::TypeMatch(tsinfo->dtype, kDLFloat, 32)) {
      dtype = "float";
-    } else if (runtime::TypeMatch(ttype->dtype, kDLFloat, 16)) {
+    } else if (runtime::TypeMatch(tsinfo->dtype, kDLFloat, 16)) {
      dtype = "half";
-    } else if (runtime::TypeMatch(ttype->dtype, kDLBfloat, 16)) {
+    } else if (runtime::TypeMatch(tsinfo->dtype, kDLBfloat, 16)) {
      dtype = "bfloat";
-    } else if (runtime::TypeMatch(ttype->dtype, kDLInt, 32)) {
+    } else if (runtime::TypeMatch(tsinfo->dtype, kDLInt, 32)) {
      dtype = "int";
-    } else if (runtime::TypeMatch(ttype->dtype, kDLInt, 64)) {
+    } else if (runtime::TypeMatch(tsinfo->dtype, kDLInt, 64)) {
      dtype = "int64_t";
-    } else if (runtime::TypeMatch(ttype->dtype, kDLInt, 8)) {
+    } else if (runtime::TypeMatch(tsinfo->dtype, kDLInt, 8)) {
      dtype = "int8_t";
-    } else if (runtime::TypeMatch(ttype->dtype, kDLUInt, 8)) {
+    } else if (runtime::TypeMatch(tsinfo->dtype, kDLUInt, 8)) {
      dtype = "uint8_t";
    } else {
-      LOG(FATAL) << "Unsupported dtype " << ttype->dtype;
+      LOG(FATAL) << "Unsupported dtype " << tsinfo->dtype;
    }
    return dtype;
diff --git a/src/relax/ir/block_builder.cc b/src/relax/ir/block_builder.cc
index 63288201e741..8f4aee382be4 100644
--- a/src/relax/ir/block_builder.cc
+++ b/src/relax/ir/block_builder.cc
@@ -95,7 +95,7 @@ class BlockBuilderImpl : public BlockBuilderNode {
      // TODO(relax-team): add fine-grained PrimFunc struct info signature generation.
finfo = FuncStructInfo::OpaqueFunc(StructInfoFromType(prim_func->ret_type)); } else { - finfo = StructInfoFromType(func->checked_type()); + TVM_FFI_THROW(RuntimeError) << "Expect struct_info field to be populated"; } UpdateStructInfo(gvar, finfo); diff --git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc index 8a4ca3f7ba0a..dd6c4f3d40cb 100644 --- a/src/relax/ir/dataflow_matcher.cc +++ b/src/relax/ir/dataflow_matcher.cc @@ -447,12 +447,6 @@ PrimExpr DFPatternMatcher::SimplifyCondition(PrimExpr condition) { return analyzer_.Simplify(sorted_condition); } -bool DFPatternMatcher::VisitDFPattern_(const TypePatternNode* op, const Expr& expr0) { - auto expr = UnwrapBindings(expr0, var2val_); - auto expr_type = expr.as()->checked_type(); - return (StructuralEqual()(op->type, expr_type)) && VisitDFPattern(op->pattern, expr); -} - static bool ShapeEqual(Analyzer* analyzer, const Array& lhs, const Array& rhs) { if (lhs.size() != rhs.size()) return false; for (size_t i = 0; i < lhs.size(); ++i) @@ -543,9 +537,9 @@ bool DFPatternMatcher::VisitDFPattern_(const PrimArrPatternNode* op, const Expr& bool DFPatternMatcher::VisitDFPattern_(const DataTypePatternNode* op, const Expr& expr) { // no need to jump, as var.dtype == value.dtype - auto expr_type = expr.as()->checked_type(); - if (const TensorTypeNode* tensor_type = expr_type.as()) { - return (StructuralEqual()(op->dtype, tensor_type->dtype)) && VisitDFPattern(op->pattern, expr); + auto expr_sinfo = expr.as()->struct_info_; + if (const TensorStructInfoNode* tensor_sinfo = expr_sinfo.as()) { + return (StructuralEqual()(op->dtype, tensor_sinfo->dtype)) && VisitDFPattern(op->pattern, expr); } return false; } diff --git a/src/relax/ir/dataflow_matcher.h b/src/relax/ir/dataflow_matcher.h index 76f48383c47c..71fa4a4c35c1 100644 --- a/src/relax/ir/dataflow_matcher.h +++ b/src/relax/ir/dataflow_matcher.h @@ -63,7 +63,6 @@ class DFPatternMatcher : public DFPatternFunctorstream << "*"; }); -TVM_REGISTER_NODE_TYPE(TypePatternNode); -TypePattern::TypePattern(DFPattern pattern, Type type) { - ObjectPtr n = make_object(); - n->pattern = std::move(pattern); - n->type = std::move(type); - data_ = std::move(n); -} -TVM_FFI_REGISTER_GLOBAL("relax.dpl.TypePattern").set_body_typed([](DFPattern pattern, Type type) { - return TypePattern(pattern, type); -}); -RELAX_PATTERN_PRINTER_DEF(TypePatternNode, [](auto p, auto node) { - p->stream << "TypePattern(" << node->pattern << " has type " << node->type << ")"; -}); - TVM_REGISTER_NODE_TYPE(StructInfoPatternNode); StructInfoPattern::StructInfoPattern(DFPattern pattern, StructInfo struct_info) { ObjectPtr n = make_object(); @@ -391,9 +377,7 @@ class DFPatternDuplicator : public DFPatternFunctor DFPattern VisitDFPattern_(const StructInfoPatternNode* op) override { return StructInfoPattern(op->pattern, op->struct_info); } - DFPattern VisitDFPattern_(const TypePatternNode* op) override { - return TypePattern(op->pattern, op->type); - } + DFPattern VisitDFPattern_(const DataflowVarPatternNode* op) override { return DataflowVarPattern(op->name); } @@ -421,7 +405,6 @@ AttrPattern DFPattern::HasAttr(const Map& attrs) const { StructInfoPattern DFPattern::HasStructInfo(const StructInfo& struct_info) const { return StructInfoPattern(*this, struct_info); } -TypePattern DFPattern::HasType(const Type& type) const { return TypePattern(*this, type); } DataTypePattern DFPattern::HasDtype(const DataType& dtype) const { return DataTypePattern(*this, dtype); } diff --git a/src/relax/ir/dataflow_pattern_functor.cc 
b/src/relax/ir/dataflow_pattern_functor.cc index 2a0c73501850..7179d6bc83a1 100644 --- a/src/relax/ir/dataflow_pattern_functor.cc +++ b/src/relax/ir/dataflow_pattern_functor.cc @@ -96,8 +96,6 @@ void DFPatternVisitor::VisitDFPattern_(const UnorderedTuplePatternNode* op) { } } -void DFPatternVisitor::VisitDFPattern_(const TypePatternNode* op) { VisitDFPattern(op->pattern); } - void DFPatternVisitor::VisitDFPattern_(const StructInfoPatternNode* op) { VisitDFPattern(op->pattern); } diff --git a/src/relax/ir/expr.cc b/src/relax/ir/expr.cc index 238cece41f61..c9d83e92389e 100644 --- a/src/relax/ir/expr.cc +++ b/src/relax/ir/expr.cc @@ -153,9 +153,6 @@ Tuple::Tuple(tvm::Array fields, Span span) { ObjectPtr n = make_object(); n->fields = std::move(fields); n->span = std::move(span); - if (tuple_sinfo) { - n->checked_type_ = GetStaticType(tuple_sinfo.value()); - } n->struct_info_ = tuple_sinfo; data_ = std::move(n); } @@ -199,7 +196,6 @@ TupleGetItem::TupleGetItem(Expr tuple, int index, Span span) { << ", and cannot be accessed with index " << index; auto sinfo = tuple_info->fields[index]; n->struct_info_ = sinfo; - n->checked_type_ = GetStaticType(sinfo); } n->tuple = std::move(tuple); n->index = index; @@ -244,7 +240,6 @@ ShapeExpr::ShapeExpr(Array values, Span span) { return value; }); n->span = span; - n->checked_type_ = ShapeType(values.size()); n->struct_info_ = ShapeStructInfo(values, span); data_ = std::move(n); } @@ -258,9 +253,6 @@ TVM_REGISTER_NODE_TYPE(VarNode); Var::Var(Id vid, Optional struct_info_annotation, Span span) { ObjectPtr n = make_object(); n->vid = std::move(vid); - if (struct_info_annotation) { - n->checked_type_ = GetStaticType(struct_info_annotation.value()); - } n->struct_info_ = std::move(struct_info_annotation); n->span = std::move(span); data_ = std::move(n); @@ -300,9 +292,6 @@ TVM_REGISTER_NODE_TYPE(DataflowVarNode); DataflowVar::DataflowVar(Id vid, Optional struct_info_annotation, Span span) { ObjectPtr n = make_object(); n->vid = std::move(vid); - if (struct_info_annotation) { - n->checked_type_ = GetStaticType(struct_info_annotation.value()); - } n->struct_info_ = std::move(struct_info_annotation); n->span = std::move(span); n->span = std::move(span); @@ -332,11 +321,9 @@ Constant::Constant(runtime::NDArray data, Optional struct_info_annot } if (struct_info_annotation.defined()) { n->struct_info_ = struct_info_annotation.value(); - n->checked_type_ = GetStaticType(struct_info_annotation.value()); } else { TensorStructInfo tinfo(ShapeExpr(values), n->data.DataType(), VDevice(), span); n->struct_info_ = tinfo; - n->checked_type_ = TensorType(tinfo->ndim, tinfo->dtype); } data_ = std::move(n); @@ -353,7 +340,6 @@ TVM_FFI_REGISTER_GLOBAL("relax.Constant") PrimValue::PrimValue(PrimExpr value, Span span) { ObjectPtr n = make_object(); - n->checked_type_ = PrimType(value.dtype()); n->struct_info_ = PrimStructInfo(value); n->value = std::move(value); n->span = std::move(span); @@ -374,9 +360,6 @@ StringImm::StringImm(String value, Span span) { ObjectPtr n = make_object(); n->value = std::move(value); n->span = std::move(span); - // use the base structinfo for now - // we can choose to introduce more fine-grained struct info later if necessary. 
- n->checked_type_ = ObjectType(); n->struct_info_ = ObjectStructInfo(); data_ = std::move(n); } @@ -391,9 +374,6 @@ DataTypeImm::DataTypeImm(DataType value, Span span) { ObjectPtr n = make_object(); n->value = std::move(value); n->span = std::move(span); - // use the base structinfo for now - // we can choose to introduce more fine-grained struct info later if necessary. - n->checked_type_ = ObjectType(); n->struct_info_ = ObjectStructInfo(); data_ = std::move(n); } @@ -619,7 +599,6 @@ Function::Function(Array params, Expr body, Optional ret_struct n->body = std::move(body); n->ret_struct_info = std::move(ret_struct_info.value()); n->is_pure = is_pure; - n->checked_type_ = GetStaticType(func_sinfo); n->struct_info_ = std::move(func_sinfo); n->attrs = std::move(attrs); n->span = std::move(span); @@ -636,8 +615,8 @@ Function Function::CreateEmpty(Array params, StructInfo ret_struct_info, bo DictAttrs attrs, Span span) { Array param_sinfo; for (const Var& param : params) { - ICHECK(param->checked_type_.defined()) - << "relax.Function requires params to contain checked_type_."; + ICHECK(param->struct_info_.defined()) + << "relax.Function requires params to contain struct_info_."; param_sinfo.push_back(GetStructInfo(param)); } @@ -656,7 +635,6 @@ Function Function::CreateEmpty(Array params, StructInfo ret_struct_info, bo n->params = std::move(params); n->body = std::move(body); n->is_pure = is_pure; - n->checked_type_ = GetStaticType(finfo); n->struct_info_ = std::move(finfo); n->ret_struct_info = std::move(ret_struct_info); n->attrs = std::move(attrs); @@ -706,7 +684,6 @@ ExternFunc::ExternFunc(String global_symbol, StructInfo struct_info, Span span) n->global_symbol = std::move(global_symbol); n->span = span; n->struct_info_ = struct_info; - n->checked_type_ = GetStaticType(struct_info); data_ = std::move(n); } diff --git a/src/relax/ir/struct_info.cc b/src/relax/ir/struct_info.cc index feb1f910a42c..9da3de96b325 100644 --- a/src/relax/ir/struct_info.cc +++ b/src/relax/ir/struct_info.cc @@ -214,8 +214,6 @@ void UpdateStructInfo(Expr expr, StructInfo struct_info) { << "However, expression " << expr << " has struct info " << expr->struct_info_ << ", which cannot be overwritten with " << struct_info; expr->struct_info_ = struct_info; - // also set checked type - expr->checked_type_ = GetStaticType(struct_info); } TVM_FFI_REGISTER_GLOBAL("relax.UpdateStructInfo") diff --git a/src/relax/transform/expand_tuple_arguments.cc b/src/relax/transform/expand_tuple_arguments.cc index 1a9afadf7e48..ec5818476a57 100644 --- a/src/relax/transform/expand_tuple_arguments.cc +++ b/src/relax/transform/expand_tuple_arguments.cc @@ -131,7 +131,7 @@ Pass ExpandTupleArguments() { if (auto func = base_func.as()) { if (auto opt = ExpandParams(func.value())) { auto new_func = opt.value(); - GlobalVar new_gvar(gvar->name_hint, new_func->checked_type_); + GlobalVar new_gvar(gvar->name_hint); new_gvar->struct_info_ = new_func->struct_info_; gvar_replacements[gvar] = new_gvar; new_callees[new_gvar] = new_func; diff --git a/src/relax/transform/fold_constant.cc b/src/relax/transform/fold_constant.cc index 7fec51086514..489b17f15c32 100644 --- a/src/relax/transform/fold_constant.cc +++ b/src/relax/transform/fold_constant.cc @@ -185,9 +185,9 @@ class ConstantFolder : public ExprMutator { bool output_not_tuple = call->sinfo_args.size() == 1; // Pattern 0: call constant function, const argument with const shape. 
    if (func && arr_args && shape && output_not_tuple) {
-      TensorType ret_type = Downcast<TensorType>(call->checked_type());
+      TensorStructInfo ret_sinfo = Downcast<TensorStructInfo>(call->struct_info_);
      // value_or will return value if it is not null, otherwise return or
-      return ConstEvaluateCallTIR(func.value(), arr_args.value(), shape.value(), ret_type->dtype)
+      return ConstEvaluateCallTIR(func.value(), arr_args.value(), shape.value(), ret_sinfo->dtype)
          .value_or({});
    }
    // TODO(hongyi): support const-fold tuple outputs
diff --git a/src/relax/transform/fuse_tir.cc b/src/relax/transform/fuse_tir.cc
index 05b7bf4218dd..90cb6a00fcfd 100644
--- a/src/relax/transform/fuse_tir.cc
+++ b/src/relax/transform/fuse_tir.cc
@@ -684,11 +684,12 @@ class FusedTIRConstructor : public ExprVisitor {
    if (it != func_info_.expr2buffers.end()) {
      int begin_buf_idx = 0;
      int end_buf_idx = 0;
-      const TupleType& tuple_type = Downcast<TupleType>(tuple_get_item->tuple->checked_type());
+      const TupleStructInfo& tuple_sinfo =
+          Downcast<TupleStructInfo>(tuple_get_item->tuple->struct_info_);
      for (int i = 0; i < tuple_get_item->index; ++i) {
-        begin_buf_idx += GetTotalTensorSize(tuple_type->fields[i]);
+        begin_buf_idx += GetTotalTensorSize(tuple_sinfo->fields[i]);
      }
-      end_buf_idx = begin_buf_idx + GetTotalTensorSize(tuple_type->fields[tuple_get_item->index]);
+      end_buf_idx = begin_buf_idx + GetTotalTensorSize(tuple_sinfo->fields[tuple_get_item->index]);
      func_info_.expr2buffers.Set(
          GetRef<Expr>(tuple_get_item),
          {(*it).second.begin() + begin_buf_idx, (*it).second.begin() + end_buf_idx});
@@ -972,17 +973,17 @@ class FusedTIRConstructor : public ExprVisitor {
  }

  /*! \brief Get DynTensor numbers from recursive Tuples. */
-  static size_t GetTotalTensorSize(const Type& type) {
-    if (type.as<TensorTypeNode>()) {
+  static size_t GetTotalTensorSize(const StructInfo& sinfo) {
+    if (sinfo.as<TensorStructInfoNode>()) {
      return 1;
-    } else if (const auto* tuple_type = type.as<TupleTypeNode>()) {
+    } else if (const auto* tuple_sinfo = sinfo.as<TupleStructInfoNode>()) {
      size_t num = 0;
-      for (const Type& type : tuple_type->fields) {
-        num += GetTotalTensorSize(type);
+      for (const StructInfo& sinfo : tuple_sinfo->fields) {
+        num += GetTotalTensorSize(sinfo);
      }
      return num;
    } else {
-      LOG(FATAL) << "TensorType and TupleType are expect, but got: " << type;
+      LOG(FATAL) << "TensorStructInfo and TupleStructInfo are expected, but got: " << sinfo;
      return 0;
    }
  }
diff --git a/src/relax/transform/lambda_lift.cc b/src/relax/transform/lambda_lift.cc
index e5e28cb55375..fb0a6a6ee16a 100644
--- a/src/relax/transform/lambda_lift.cc
+++ b/src/relax/transform/lambda_lift.cc
@@ -328,13 +328,6 @@ class LambdaLifter : public ExprMutator {
          Function(lifted_func_params, body, ret_struct_info, func_node->is_pure, func_node->attrs);
    }

-    for (Var param : lifted_func->params) {
-      CHECK(param->checked_type_.defined())
-          << "relax.Function requires all parameters to contain checked_type_. "
-          << "However, parameter " << param << " with struct info " << param->struct_info_
-          << " has no checked type";
-    }
-
    ICHECK(lifted_func.defined());

    if (is_closure || IsClosure(lifted_func)) {
@@ -344,7 +337,6 @@
      // Add the lifted function to the module.
lifted_func = CopyWithNewVars(lifted_func); gvar_lifted_func->struct_info_ = GetStructInfo(lifted_func); - gvar_lifted_func->checked_type_ = lifted_func->checked_type_; builder_->UpdateFunction(gvar_lifted_func, lifted_func); diff --git a/src/relax/transform/normalize.cc b/src/relax/transform/normalize.cc index 07ca6a1133e7..d0c41ff77a37 100644 --- a/src/relax/transform/normalize.cc +++ b/src/relax/transform/normalize.cc @@ -20,7 +20,7 @@ /*! * \file tvm/relax/transform/normalize.cc * \brief Pass for transforming Relax IR to normal form, i.e., the expressions are normalized(no - * nesting and hence the AST is in ANF), and all checked_type_ and shape_ of expressions are + * nesting and hence the AST is in ANF), and all struct_info_ of expressions are * available. */ diff --git a/src/relax/transform/remove_unused_outputs.cc b/src/relax/transform/remove_unused_outputs.cc index e170588f60c6..ea8a8fa14f29 100644 --- a/src/relax/transform/remove_unused_outputs.cc +++ b/src/relax/transform/remove_unused_outputs.cc @@ -240,7 +240,7 @@ Pass RemoveUnusedOutputs() { const auto& usage_mask = it->second; auto new_func = UpdateCallee(func.value(), usage_mask); - GlobalVar new_gvar(gvar->name_hint, new_func->checked_type_); + GlobalVar new_gvar(gvar->name_hint); new_gvar->struct_info_ = new_func->struct_info_; new_callees->Add(new_gvar, new_func); diff --git a/src/relax/transform/remove_unused_parameters.cc b/src/relax/transform/remove_unused_parameters.cc index 911e427935be..5018232668b9 100644 --- a/src/relax/transform/remove_unused_parameters.cc +++ b/src/relax/transform/remove_unused_parameters.cc @@ -194,7 +194,7 @@ Pass RemoveUnusedParameters() { if (auto func = base_func.as()) { if (auto callee_res = AnalyzeCallee(func.value())) { auto new_func = callee_res->func; - GlobalVar new_gvar(gvar->name_hint, new_func->checked_type_); + GlobalVar new_gvar(gvar->name_hint); new_gvar->struct_info_ = new_func->struct_info_; new_callees->Add(new_gvar, new_func); diff --git a/src/script/ir_builder/ir/ir.cc b/src/script/ir_builder/ir/ir.cc index 270f4623ef0c..111f2beae328 100644 --- a/src/script/ir_builder/ir/ir.cc +++ b/src/script/ir_builder/ir/ir.cc @@ -62,15 +62,12 @@ GlobalVar DeclFunction(const String& func_name, const BaseFunc& func_signature) return {}; }(); - GlobalVar gv = GlobalVar(func_name, gvar_type); + GlobalVar gv = GlobalVar(func_name); gv->struct_info_ = GetGlobalVarStructInfo(func_signature); CHECK(frame->functions.find(gv) == frame->functions.end()) << "ValueError: function " << func_name << " has already been defined."; frame->global_var_map.Set(func_name, gv); frame->functions.Set(gv, func_signature); - ICHECK(func_signature->checked_type_.defined()) - << "The checked_type_ of function signature must be defined."; - gv->checked_type_ = func_signature->checked_type_; return gv; } @@ -81,11 +78,7 @@ void DefFunction(const String& func_name, const BaseFunc& func) { << "ValueError: function " << func_name << " does not exist, please declare it first."; const GlobalVar& gv = (*it).second; frame->functions.Set(gv, func); - CHECK(func->checked_type_.defined()) - << "The checked_type_ of function must be defined, but it is not defined for function `" - << func_name << "`."; gv->struct_info_ = GetGlobalVarStructInfo(func); - gv->checked_type_ = func->checked_type_; } void ModuleAttrs(Map attrs, bool allow_overwrite) { diff --git a/src/tir/ir/function.cc b/src/tir/ir/function.cc index 2312d31fd276..9eb9bc26e343 100644 --- a/src/tir/ir/function.cc +++ b/src/tir/ir/function.cc @@ -84,7 +84,6 @@ 
PrimFunc::PrimFunc(Array<tir::Var> params, Stmt body, Type ret_type,
  n->ret_type = std::move(ret_type);
  n->buffer_map = std::move(buffer_map);
  n->attrs = std::move(attrs);
-  n->checked_type_ = n->func_type_annotation();
  n->struct_info_ = relax::FuncStructInfo::OpaqueFunc();
  n->span = std::move(span);
  data_ = std::move(n);
diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc
index 250f51d09c57..48cf7bbad1f5 100644
--- a/src/tir/transforms/make_packed_api.cc
+++ b/src/tir/transforms/make_packed_api.cc
@@ -389,7 +389,6 @@ PrimFunc MakePackedAPI(PrimFunc func) {
      << " are used, but are not passed in as API arguments";

  func_ptr->buffer_map = Map<tir::Var, Buffer>();
-  func_ptr->checked_type_ = func_ptr->func_type_annotation();
  func_ptr->ret_type = PrimType(DataType::Int(32));

  // return the function.
diff --git a/tests/python/relax/test_analysis_well_formed.py b/tests/python/relax/test_analysis_well_formed.py
index d9eefcfd0ef2..fd7c70c6148c 100644
--- a/tests/python/relax/test_analysis_well_formed.py
+++ b/tests/python/relax/test_analysis_well_formed.py
@@ -281,7 +281,7 @@ def test_if_non_seq_body():
    ]
    new_func = build_function(new_blocks)
    new_mod = tvm.IRModule.from_expr(new_func)
-    # apply normalization to fill in checked_type_
+    # apply normalization to fill in struct_info_
    normalized = rx.transform.Normalize()(new_mod)
    assert rx.analysis.well_formed(normalized, check_struct_info=True)

@@ -320,7 +320,7 @@ def test_if_complex_condition():
    ]
    func = build_function(blocks)
    mod = tvm.IRModule.from_expr(func)
-    # apply normalization to fill in checked_type_
+    # apply normalization to fill in struct_info_
    normalized = rx.transform.Normalize()(mod)
    assert rx.analysis.well_formed(normalized, check_struct_info=True)
diff --git a/tests/python/relax/test_ast_printer.py b/tests/python/relax/test_ast_printer.py
index 6f55c542ea05..6a65b1b751c7 100644
--- a/tests/python/relax/test_ast_printer.py
+++ b/tests/python/relax/test_ast_printer.py
@@ -29,7 +29,7 @@
from tvm.script import tir as T

-# Overload dump_ast to test both struct info and type annotations
-dump_ast = partial(dump_ast, include_struct_info_annotations=True, include_type_annotations=True)
+# Overload dump_ast to test struct info annotations
+dump_ast = partial(dump_ast, include_struct_info_annotations=True)


def strip_whitespace(text: str) -> str:
@@ -41,7 +41,7 @@
def normalize(func: rx.Function) -> rx.Function:
    """
-    Normalize the expr to fill in the checked_type_ and struct_info fields everywhere
+    Normalize the expr to fill in the struct_info fields everywhere
    """

    # using a default mutator to use the BlockBuilder's normalizer,
@@ -79,15 +79,12 @@ def test_var() -> None:
    assert v0_str == 'Var(name_hint="v0")'

    v1 = rx.Var("v1", R.Tensor([54, 96], "float32"))
-    v1_no_annos = dump_ast(
-        v1, include_struct_info_annotations=False, include_type_annotations=False
-    )
+    v1_no_annos = dump_ast(v1, include_struct_info_annotations=False)
    assert v1_no_annos == 'Var(name_hint="v1")'
    v1_annos =
dump_ast(v1) assert v1_annos != v1_no_annos assert "PrimExpr" in v1_annos assert "struct_info" in v1_annos - assert "checked_type_" in v1_annos def test_match_cast() -> None: @@ -121,7 +115,6 @@ def test_match_cast() -> None: assert "PrimExpr(value=`n" in b0_str assert "16" in b0_str assert "8" in b0_str - assert b0_str != dump_ast(b0, include_type_annotations=False) # var1: Tensor((m, n), "float32") = # match_cast(var0: R.Tensor("float32"), [m, n]) @@ -132,16 +125,14 @@ def test_match_cast() -> None: assert b1_str.startswith("MatchCast(") assert "PrimExpr(value=`m" in b1_str assert "PrimExpr(value=`n" in b1_str - assert b1_str != dump_ast( - b1, include_type_annotations=False, include_struct_info_annotations=False - ) + assert b1_str != dump_ast(b1, include_struct_info_annotations=False) def test_var_binding() -> None: v0 = rx.Var("v0") val = rx.const(np.random.rand(24, 56)) b0 = rx.VarBinding(v0, val) - b0_str = dump_ast(b0, include_type_annotations=False, include_struct_info_annotations=False) + b0_str = dump_ast(b0, include_struct_info_annotations=False) assert b0_str.startswith("VarBinding(") assert 'var=Var(name_hint="v0")' in b0_str assert "value=" in b0_str @@ -232,7 +223,6 @@ def test_func(): assert "SeqExpr(" in func_str assert "blocks=" in func_str assert "VarBinding(" in func_str - assert func_str != dump_ast(func, include_type_annotations=False) def test_shape_of(): @@ -290,7 +280,7 @@ def test_types(): def test_struct_info(): - printer = ASTPrinter(include_type_annotations=True) + printer = ASTPrinter() assert printer.visit_struct_info_(rx.ObjectStructInfo()) == "ObjectStructInfo()" @@ -330,8 +320,7 @@ def test_struct_info(): dtype=int32, shape=Var( name_hint="x", - struct_info=ShapeStructInfo(ndim=0, values=[]), - checked_type_=ShapeType(ndim=0) + struct_info=ShapeStructInfo(ndim=0, values=[]) ) ) """ @@ -379,7 +368,6 @@ def f( f_str = strip_whitespace( dump_ast( f, - include_type_annotations=False, include_struct_info_annotations=False, include_call_attrs=True, ) @@ -394,7 +382,6 @@ def f( extern_call = f.body.blocks[0].bindings[-1].value extern_call_text = dump_ast( extern_call, - include_type_annotations=False, include_struct_info_annotations=False, include_call_attrs=True, ) @@ -414,7 +401,6 @@ def f( op_call = f.body.blocks[0].bindings[0].value op_call_text = dump_ast( op_call, - include_type_annotations=False, include_struct_info_annotations=False, include_call_attrs=True, ) @@ -459,7 +445,6 @@ def foo(x: R.Tensor(("m", "n"), "float32")): foo_str = strip_whitespace( dump_ast( foo, - include_type_annotations=False, include_struct_info_annotations=False, include_call_attrs=False, ) @@ -471,7 +456,6 @@ def foo(x: R.Tensor(("m", "n"), "float32")): tir_call = foo.body.blocks[0].bindings[0].value tir_call_text = dump_ast( tir_call, - include_type_annotations=False, include_struct_info_annotations=False, include_call_attrs=False, ) @@ -510,7 +494,6 @@ def foo(x: R.Tensor(("m", "n"), "float32")): foo_str = strip_whitespace( dump_ast( foo, - include_type_annotations=False, include_struct_info_annotations=False, include_call_attrs=False, ) @@ -522,7 +505,6 @@ def foo(x: R.Tensor(("m", "n"), "float32")): tir_call = foo.body.blocks[0].bindings[0].value tir_call_text = dump_ast( tir_call, - include_type_annotations=False, include_struct_info_annotations=False, include_call_attrs=False, ) @@ -559,7 +541,6 @@ def foo(x: R.Tensor): foo_str = strip_whitespace( dump_ast( foo, - include_type_annotations=False, include_struct_info_annotations=False, ) ) @@ -576,7 +557,6 @@ def bar(x: 
R.Tensor): bar_str = strip_whitespace( dump_ast( bar, - include_type_annotations=False, include_struct_info_annotations=False, ) ) @@ -601,8 +581,7 @@ def f() -> R.Tensor: struct_info=ShapeStructInfo( ndim=1, values=[PrimExpr(value=`T.int64(2)`)] - ), - checked_type_=ShapeType(ndim=1) + ) ) ) """ @@ -619,15 +598,6 @@ def f() -> R.Shape: assert isinstance(body, rx.SeqExpr) call = body.blocks[-1].bindings[-1].value assert isinstance(call, rx.Call) - arg = call.args[0] - arg_str = strip_whitespace(dump_ast(arg)) - # the constant should have a tensor type - assert "checked_type_=TensorType(ndim=0" in arg_str - - call_str = strip_whitespace(dump_ast(call)) - # we expect the shape_of call to have a checked_type_ of ShapeType - type_str = "checked_type_=ShapeType(ndim=0)" - assert type_str in call_str def test_if(): @@ -669,8 +639,7 @@ def test_prim_value(): """ PrimValue( value=PrimExpr(value=`T.int64(1)`), - struct_info=PrimStructInfo(dtype=int64), - checked_type_=PrimType(dtype=int64) + struct_info=PrimStructInfo(dtype=int64) ) """ ) @@ -683,8 +652,7 @@ def test_string_imm(): """ StringImm( value="test", - struct_info=ObjectStructInfo(), - checked_type_=ObjectType() + struct_info=ObjectStructInfo() ) """ ) @@ -697,8 +665,7 @@ def test_datatype_imm(): """ DataTypeImm( value=int32, - struct_info=ObjectStructInfo(), - checked_type_=ObjectType() + struct_info=ObjectStructInfo() ) """ ) diff --git a/tests/python/relax/test_blockbuilder_core.py b/tests/python/relax/test_blockbuilder_core.py index efab59a0e683..be60524e8475 100644 --- a/tests/python/relax/test_blockbuilder_core.py +++ b/tests/python/relax/test_blockbuilder_core.py @@ -206,9 +206,9 @@ def test_binary_shape_type_deduction(): gv0 = bb.emit_output(lv3) bb.emit_func_output(gv0) - assert isinstance(gv0.checked_type, rx.TensorType) - assert gv0.checked_type.ndim == 1 - assert gv0.checked_type.dtype == "float16" + assert isinstance(gv0.struct_info, rx.TensorStructInfo) + assert gv0.struct_info.ndim == 1 + assert gv0.struct_info.dtype == "float16" def test_emit_match_cast(): @@ -301,11 +301,7 @@ def test_normalize(): # Nested Tuple tuple_2 = rx.Tuple([x, rx.Tuple([x, y])]) bb.normalize(tuple_2) - type_anno0 = x.checked_type - type_anno1 = y.checked_type - assert_structural_equal( - tuple_2.checked_type, rx.TupleType([type_anno0, rx.TupleType([type_anno0, type_anno1])]) - ) + assert isinstance(tuple_2.struct_info, rx.TupleStructInfo) assert isinstance(tuple_2.struct_info.fields[0], rx.TensorStructInfo) assert isinstance(tuple_2.struct_info.fields[1], rx.TupleStructInfo) diff --git a/tests/python/relax/test_dataflow_pattern.py b/tests/python/relax/test_dataflow_pattern.py index 9be0d761c11f..90e2948a320c 100644 --- a/tests/python/relax/test_dataflow_pattern.py +++ b/tests/python/relax/test_dataflow_pattern.py @@ -226,10 +226,6 @@ def test_not_pattern(): assert not no_shape233.match(rx.Var("x", R.Tensor((2, 3, 3), "float32"))) -def test_type_pattern(): - assert wildcard().has_type(rx.TensorType(2, "float32")).match(bindings[0].var) - - def test_dtype_pattern(): dtype = "float16" pattern = has_dtype(dtype) diff --git a/tests/python/relax/test_expr.py b/tests/python/relax/test_expr.py index c20f0b268173..a8232fbc8f7d 100644 --- a/tests/python/relax/test_expr.py +++ b/tests/python/relax/test_expr.py @@ -41,28 +41,24 @@ def _check_json_roundtrip(x): def test_var() -> None: v0 = rx.Var("v0") assert v0.name_hint == "v0" - assert v0._checked_type_ is None assert v0.struct_info_ is None shape = [54, 96] v1 = rx.Var("v1", R.Tensor(shape, "float32")) 
assert v1.name_hint == "v1" for s0, s1 in zip(v1.struct_info.shape, shape): assert s0 == s1 - assert v1.checked_type == rx.TensorType(2, "float32") tvm.ir.assert_structural_equal(v1.struct_info, rx.TensorStructInfo(shape, "float32")) def test_dataflow_var() -> None: v0 = rx.DataflowVar("v0") assert v0.name_hint == "v0" - assert v0._checked_type_ is None assert v0.struct_info_ is None shape = [54, 96] v1 = rx.DataflowVar("v1", R.Tensor(shape, "float16")) assert v1.name_hint == "v1" - assert v1._checked_type_ == rx.TensorType(2, "float16") assert isinstance(v1, rx.DataflowVar) tvm.ir.assert_structural_equal(v1.struct_info, rx.TensorStructInfo(shape, "float16")) @@ -116,7 +112,6 @@ def test_match_cast() -> None: assert b0.pattern[0] == m assert b0.pattern[1] == n assert b0.var is not None - assert b0.var.checked_type == rx.ShapeType() # var1: R.Tensor((m, n), "float32") = # match_cast(var0: R.Tensor("float32", ndim=-1), R.Tensor((m, n), "float32")) @@ -128,7 +123,6 @@ def test_match_cast() -> None: assert b1.pattern[0] == m assert b1.pattern[1] == n assert b1.var is not None - assert b1.var.checked_type == rx.TensorType(2, "float32") def test_match_cast() -> None: @@ -234,13 +228,11 @@ def test_shape_expr(): shape_expr = rx.ShapeExpr([10, 20]) assert shape_expr.values[0] == 10 assert shape_expr.values[1] == 20 - assert shape_expr.checked_type == rx.ShapeType(ndim=2) tvm.ir.assert_structural_equal(shape_expr.struct_info, R.Shape((10, 20))) x = rx.Var("v0", R.Tensor((10, 20), "float32")) assert x.struct_info.shape[0] == 10 assert x.struct_info.shape[1] == 20 - assert x.struct_info.shape.checked_type == rx.ShapeType(ndim=2) tvm.ir.assert_structural_equal(x.struct_info.shape.struct_info, R.Shape((10, 20))) m = tir.Var("m", "int32") diff --git a/tests/python/relax/test_expr_functor.py b/tests/python/relax/test_expr_functor.py index f3d2432549e1..46d46e239d3a 100644 --- a/tests/python/relax/test_expr_functor.py +++ b/tests/python/relax/test_expr_functor.py @@ -365,7 +365,7 @@ def visit(f, expr): # check no overloading case basic_mutator = BasicMutator() - # skip normalize GlobalVar since it requires context IRModule to get the checked_type_ + # skip normalize GlobalVar since it requires context IRModule to get the struct_info_ if isinstance(expr, relax.Expr) and not isinstance(expr, relax.GlobalVar): expr = bb.normalize(expr) assert_structural_equal(visit(basic_mutator, expr), expr)
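The same substitution applies to function-valued expressions: code that read gvar.checked_type.ret_type now reads gvar.struct_info.ret, mirroring the tir.call_tir change earlier in this patch. A short sketch (the module and function names below are hypothetical; it assumes TVMScript parsing populates struct_info on the GlobalVar, which the DeclFunction/DefFunction changes above provide):

    from tvm.script import ir as I, relax as R

    @I.ir_module
    class Module:
        @R.function
        def main(x: R.Tensor((4,), "float32")) -> R.Tensor((4,), "float32"):
            return x

    gvar = Module.get_global_var("main")
    # GlobalVar now carries FuncStructInfo rather than a checked FuncType,
    # so the return annotation is exposed as a TensorStructInfo.
    assert gvar.struct_info.ret.dtype == "float32"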