Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 1 addition & 47 deletions include/tvm/ir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -400,38 +400,13 @@ TVM_DLL PrimExpr operator~(PrimExpr a);
*/
class RelaxExprNode : public BaseExprNode {
public:
/*!
* \brief Stores the result of type inference(type checking).
*
* \note This can be undefined before type inference.
* This value is discarded during serialization.
*/
mutable Type checked_type_ = Type(nullptr);

/*!
* \brief Stores the result of structure information of the
* expression that encapsulate both static shape and
* runtime information such as shape.
*/
mutable Optional<ObjectRef> struct_info_ = Optional<ObjectRef>();

/*!
* \return The checked_type
*/
inline const Type& checked_type() const;
/*!
* \brief Check if the inferred(checked) type of the Expr
* is backed by a TTypeNode and return it.
*
* \note This function will thrown an error if the node type
* of this Expr is not TTypeNode.
*
* \return The corresponding TTypeNode pointer.
* \tparam The specific TypeNode we look for.
*/
template <typename TTypeNode>
inline const TTypeNode* type_as() const;

static constexpr const char* _type_key = "RelaxExpr";
static constexpr const uint32_t _type_child_slots = 22;
TVM_DECLARE_BASE_OBJECT_INFO(RelaxExprNode, BaseExprNode);
Expand Down Expand Up @@ -463,7 +438,6 @@ class GlobalVarNode : public RelaxExprNode {
void VisitAttrs(AttrVisitor* v) {
v->Visit("name_hint", &name_hint);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
v->Visit("struct_info_", &struct_info_);
}

Expand All @@ -487,7 +461,7 @@ class GlobalVarNode : public RelaxExprNode {
*/
class GlobalVar : public RelaxExpr {
public:
TVM_DLL explicit GlobalVar(String name_hint, Type type = {}, Span span = {});
TVM_DLL explicit GlobalVar(String name_hint, Span span = {});

TVM_DEFINE_OBJECT_REF_METHODS(GlobalVar, RelaxExpr, GlobalVarNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(GlobalVarNode);
Expand Down Expand Up @@ -747,26 +721,6 @@ class Range : public ObjectRef {
TVM_DEFINE_OBJECT_REF_METHODS(Range, ObjectRef, RangeNode);
};

// implementations
inline const Type& RelaxExprNode::checked_type() const {
ICHECK(checked_type_.defined()) << "internal error: the type checker has "
<< "not populated the checked_type "
<< "field for " << GetRef<RelaxExpr>(this);
return this->checked_type_;
}

template <typename TTypeNode>
inline const TTypeNode* RelaxExprNode::type_as() const {
static_assert(std::is_base_of<TypeNode, TTypeNode>::value,
"TType must be a special case of type");
ICHECK(checked_type_.defined())
<< "Type inference for this Expr has not completed. Try to call infer_type pass.";
const TTypeNode* node = checked_type_.as<TTypeNode>();
ICHECK(node != nullptr) << "Expected type to be " << TTypeNode::_type_key << ", but get "
<< checked_type_->GetTypeKey();
return node;
}

namespace ffi {
// Type traits to enable automatic conversion into IntImm, Integer, and Bool
// when called through the FFI
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/ir/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ constexpr const char* kGlobalSymbol = "global_symbol";
* \brief Base node of all functions.
*
* We support several variants of functions throughout the stack.
* All of the functions share the same type system(via checked_type)
* All of the functions share the same type system
* to support cross variant calls.
*
* \sa BaseFunc
Expand Down
31 changes: 0 additions & 31 deletions include/tvm/relax/dataflow_pattern.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ class AndPattern;
class NotPattern;
class ShapePattern;
class StructInfoPattern;
class TypePattern;
class DataTypePattern;
class AttrPattern;
class SameShapeConstraint;
Expand Down Expand Up @@ -116,8 +115,6 @@ class DFPattern : public ObjectRef {
TVM_DLL AttrPattern HasAttr(const Map<String, Any>& attrs) const;
/*! \brief Syntatic Sugar for creating a StructInfoPattern */
TVM_DLL StructInfoPattern HasStructInfo(const StructInfo& struct_info) const;
/*! \brief Syntatic Sugar for creating a TypePattern */
TVM_DLL TypePattern HasType(const Type& type) const;
/*! \brief Syntatic Sugar for creating a DataTypePattern with a DataType */
TVM_DLL DataTypePattern HasDtype(const DataType& dtype) const;
/*! \brief Syntatic Sugar for creating a DataTypePattern with a data type's name */
Expand Down Expand Up @@ -742,34 +739,6 @@ class WildcardPattern : public DFPattern {
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(WildcardPattern, DFPattern, WildcardPatternNode);
};

/*!
* \brief Pattern for matching a certain type.
* \sa TypePattern
*/
class TypePatternNode : public DFPatternNode {
public:
DFPattern pattern; /*!< The pattern to match */
Type type; /*!< The type to match */

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("pattern", &pattern);
v->Visit("type", &type);
}

static constexpr const char* _type_key = "relax.dpl.TypePattern";
TVM_DECLARE_FINAL_OBJECT_INFO(TypePatternNode, DFPatternNode);
};

/*!
* \brief Managed reference to TypePatternNode.
* \sa TypePatternNode
*/
class TypePattern : public DFPattern {
public:
TVM_DLL TypePattern(DFPattern pattern, Type type);
TVM_DEFINE_OBJECT_REF_METHODS(TypePattern, DFPattern, TypePatternNode);
};

/*!
* \brief Pattern for matching a certain struct info.
* \sa StructInfoPattern
Expand Down
3 changes: 0 additions & 3 deletions include/tvm/relax/dataflow_pattern_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,6 @@ class DFPatternFunctor<R(const DFPattern& n, Args...)> {
virtual R VisitDFPattern_(const TuplePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const StructInfoPatternNode* op,
Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const TypePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const WildcardPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const VarPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;

Expand Down Expand Up @@ -132,7 +131,6 @@ class DFPatternFunctor<R(const DFPattern& n, Args...)> {
RELAX_DFPATTERN_FUNCTOR_DISPATCH(TupleGetItemPatternNode);
RELAX_DFPATTERN_FUNCTOR_DISPATCH(TuplePatternNode);
RELAX_DFPATTERN_FUNCTOR_DISPATCH(StructInfoPatternNode);
RELAX_DFPATTERN_FUNCTOR_DISPATCH(TypePatternNode);
RELAX_DFPATTERN_FUNCTOR_DISPATCH(WildcardPatternNode);
RELAX_DFPATTERN_FUNCTOR_DISPATCH(VarPatternNode);
RELAX_DFPATTERN_FUNCTOR_DISPATCH(DataflowVarPatternNode);
Expand Down Expand Up @@ -167,7 +165,6 @@ class DFPatternVisitor : public DFPatternFunctor<void(const DFPattern&)> {
void VisitDFPattern_(const TupleGetItemPatternNode* op) override;
void VisitDFPattern_(const TuplePatternNode* op) override;
void VisitDFPattern_(const StructInfoPatternNode* op) override;
void VisitDFPattern_(const TypePatternNode* op) override;
void VisitDFPattern_(const WildcardPatternNode* op) override;
void VisitDFPattern_(const VarPatternNode* op) override;

Expand Down
14 changes: 0 additions & 14 deletions include/tvm/relax/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,6 @@ class CallNode : public ExprNode {
v->Visit("attrs", &attrs);
v->Visit("sinfo_args", &sinfo_args);
v->Visit("struct_info_", &struct_info_);
v->Visit("_checked_type_", &checked_type_);
v->Visit("span", &span);
}

Expand Down Expand Up @@ -224,7 +223,6 @@ class TupleNode : public ExprNode {

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("fields", &fields);
v->Visit("_checked_type_", &checked_type_);
v->Visit("struct_info_", &struct_info_);
v->Visit("span", &span);
}
Expand Down Expand Up @@ -291,7 +289,6 @@ class TupleGetItemNode : public ExprNode {
v->Visit("tuple_value", &tuple);
v->Visit("index", &index);
v->Visit("struct_info_", &struct_info_);
v->Visit("_checked_type_", &checked_type_);
v->Visit("span", &span);
}

Expand Down Expand Up @@ -362,7 +359,6 @@ class ShapeExprNode : public LeafExprNode {
void VisitAttrs(AttrVisitor* v) {
v->Visit("values", &values);
v->Visit("struct_info_", &struct_info_);
v->Visit("_checked_type_", &checked_type_);
v->Visit("span", &span);
}

Expand Down Expand Up @@ -399,7 +395,6 @@ class VarNode : public LeafExprNode {
void VisitAttrs(AttrVisitor* v) {
v->Visit("vid", &vid);
v->Visit("struct_info_", &struct_info_);
v->Visit("_checked_type_", &checked_type_);
v->Visit("span", &span);
}

Expand Down Expand Up @@ -440,7 +435,6 @@ class DataflowVarNode : public VarNode {
void VisitAttrs(AttrVisitor* v) {
v->Visit("vid", &vid);
v->Visit("struct_info_", &struct_info_);
v->Visit("_checked_type_", &checked_type_);
v->Visit("span", &span);
}

Expand Down Expand Up @@ -492,7 +486,6 @@ class ConstantNode : public LeafExprNode {
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("data", &data);
v->Visit("struct_info_", &struct_info_);
v->Visit("_checked_type_", &checked_type_);
v->Visit("span", &span);
}

Expand Down Expand Up @@ -540,7 +533,6 @@ class PrimValueNode : public LeafExprNode {
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("value", &value);
v->Visit("struct_info_", &struct_info_);
v->Visit("_checked_type_", &checked_type_);
v->Visit("span", &span);
}

Expand Down Expand Up @@ -591,7 +583,6 @@ class StringImmNode : public LeafExprNode {
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("value", &value);
v->Visit("struct_info_", &struct_info_);
v->Visit("_checked_type_", &checked_type_);
v->Visit("span", &span);
}

Expand Down Expand Up @@ -634,7 +625,6 @@ class DataTypeImmNode : public LeafExprNode {
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("value", &value);
v->Visit("struct_info_", &struct_info_);
v->Visit("_checked_type_", &checked_type_);
v->Visit("span", &span);
}

Expand Down Expand Up @@ -824,7 +814,6 @@ class SeqExprNode : public ExprNode {
v->Visit("blocks", &blocks);
v->Visit("body", &body);
v->Visit("struct_info_", &struct_info_);
v->Visit("_checked_type_", &checked_type_);
v->Visit("span", &span);
}

Expand Down Expand Up @@ -889,7 +878,6 @@ class IfNode : public ExprNode {
v->Visit("cond", &cond);
v->Visit("true_branch", &true_branch);
v->Visit("false_branch", &false_branch);
v->Visit("_checked_type_", &checked_type_);
v->Visit("struct_info_", &struct_info_);
v->Visit("span", &span);
}
Expand Down Expand Up @@ -966,7 +954,6 @@ class FunctionNode : public BaseFuncNode {
v->Visit("ret_struct_info", &ret_struct_info);
v->Visit("attrs", &attrs);
v->Visit("struct_info_", &struct_info_);
v->Visit("_checked_type_", &checked_type_);
v->Visit("span", &span);
}

Expand Down Expand Up @@ -1071,7 +1058,6 @@ class ExternFuncNode : public BaseFuncNode {
void VisitAttrs(AttrVisitor* v) {
v->Visit("global_symbol", &global_symbol);
v->Visit("struct_info_", &struct_info_);
v->Visit("_checked_type_", &checked_type_);
v->Visit("span", &span);
}

Expand Down
4 changes: 2 additions & 2 deletions include/tvm/relax/expr_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@ void PostOrderVisit(const Expr& node, std::function<void(const Expr&)> fvisit);
/*!
* \brief A mutator works in unnormalized form.
*
* ExprMutatorBase expects input AST to be in the unnormalized form, i.e., checked_type_ and shape_
* ExprMutatorBase expects input AST to be in the unnormalized form, i.e., struct_info_
* of expressions can be nullptr, and the expressions may nest(and as a result the AST is not in
* ANF).
*/
Expand Down Expand Up @@ -414,7 +414,7 @@ class ExprMutatorBase : public ExprFunctor<Expr(const Expr&)> {
* \brief A mutator works in normal form.
*
* ExprMutator expects input AST to be in the normal form, i.e., the expressions are normalized(no
* nesting and hence the AST is in ANF), and all checked_type_ and shape_ of expressions are
* nesting and hence the AST is in ANF), and all struct_info_ of expressions are
* available.
*/
class ExprMutator : public ExprMutatorBase {
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/relax/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ TVM_DLL Pass AttachGlobalSymbol();

/*!
* \brief Transform Relax IR to normal form: transform AST to A-normal form, and fill the
* checked_type_ and shape_ of expressions.
* struct_info_ of expressions.
*
* \return The Pass.
*/
Expand Down
1 change: 0 additions & 1 deletion include/tvm/tir/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,6 @@ class PrimFuncNode : public BaseFuncNode {
v->Visit("buffer_map", &buffer_map);
v->Visit("attrs", &attrs);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}

bool SEqualReduce(const PrimFuncNode* other, SEqualReducer equal) const {
Expand Down
19 changes: 2 additions & 17 deletions python/tvm/ir/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
from ..runtime import Object, Scriptable
from . import _ffi_api
from .base import Node, Span
from .type import Type


class BaseExpr(Node):
Expand All @@ -45,20 +44,6 @@ class PrimExpr(BaseExpr):
class RelaxExpr(BaseExpr):
"""Base class of all non-primitive expressions."""

@property
def checked_type(self):
"""Get the checked type of tvm.relax.Expr.

Returns
-------
checked_type : tvm.ir.Type
The checked type.
"""
ret = self._checked_type_
if ret is None:
raise ValueError("The type checker has not populated the checked_type for this node")
return ret

@property
def struct_info(self) -> Optional["tvm.relax.StructInfo"]:
"""Get the struct info field
Expand Down Expand Up @@ -86,8 +71,8 @@ class GlobalVar(RelaxExpr):

name_hint: str

def __init__(self, name_hint: str, type_annot: Optional[Type] = None):
self.__init_handle_by_constructor__(_ffi_api.GlobalVar, name_hint, type_annot)
def __init__(self, name_hint: str):
self.__init_handle_by_constructor__(_ffi_api.GlobalVar, name_hint)

def __call__(self, *args: RelaxExpr) -> BaseExpr:
"""Call the global variable.
Expand Down
Loading
Loading