diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index c8531c88465a..d4ba628d36cf 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -367,6 +367,14 @@ class RelayExprNode : public BaseExprNode { * This value is discarded during serialization. */ mutable Type checked_type_ = Type(nullptr); + + /*! + * \brief Stores the result of structure information of the + * expression that encapsulate both static shape and + * runtime information such as shape. + */ + mutable Optional struct_info_ = Optional(); + /*! * \return The checked_type */ diff --git a/include/tvm/relax/expr.h b/include/tvm/relax/expr.h new file mode 100644 index 000000000000..8154b1dd86de --- /dev/null +++ b/include/tvm/relax/expr.h @@ -0,0 +1,1003 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_RELAX_EXPR_H_ +#define TVM_RELAX_EXPR_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace tvm { +namespace relax { + +using Expr = RelayExpr; +using ExprNode = RelayExprNode; +using relay::Id; + +/*! + * \brief Base type of all structure information. + * + * StructInfo stores possible structure information + * deduced during compile-time. It encapsulates + * both static type and runtime information such + * as shape. + * + * StructInfo of each non-primitive Expr can be + * deduced during compilation in a "best-effort" manner. + * + * When struct_info appears in function parameter and return + * signatures. They will imply a runtime check that matches + * the structure information with the value. + * + * When it appears in Expr, they follow "assume-semantics", + * which means the compiler will take the deduced information as it is + * and only do best effort prove and checks. + * + * Each struct info can be uniquely erased to a static-type. + * The compiler will still compile the code(with less information) + * when we erase to the static type. + * + * If an StructInfo contains an Expr field, then that field + * must be normalized already through NormalizeArg. + * This invariant will be checked in constructors + * and help us to simplify our assumption + * during struct info deduction. + */ +class StructInfoNode : public Object { + public: + /*! + * \brief Span that points to the original source code. + * Reserved debug information. + */ + mutable Span span; + + static constexpr const char* _type_key = "StructInfo"; + static constexpr const bool _type_has_method_sequal_reduce = true; + static constexpr const bool _type_has_method_shash_reduce = true; + static constexpr const uint32_t _type_child_slots = 5; + TVM_DECLARE_BASE_OBJECT_INFO(StructInfoNode, Object); +}; + +/*! + * \brief Managed reference to StructInfoNode. + * \sa StructInfoNode + */ +class StructInfo : public ObjectRef { + public: + TVM_DEFINE_OBJECT_REF_METHODS(StructInfo, ObjectRef, StructInfoNode); +}; + +/*! + * \brief Call corresponds to callable invocation. + * Corresponds to operation in computational graph terminology. + */ +class CallNode : public ExprNode { + public: + /*! + * \brief The operator(function) being invoked + * + * - It can be tvm::Op which corresponds to the primitive operators. + * - It can also be user defined functions (Function, GlobalVar, Var). + */ + Expr op; + + /*! \brief The arguments(inputs) of the call */ + tvm::Array args; + + /*! \brief The additional attributes */ + Attrs attrs; + + /*! + * \brief The structure info arguments of a CallNode. + * sinfo_args is designed to be non-empty only for intrinsic op (e.g., + * call_tir, call_builtin_with_ctx, etc.) and calls to ExternFuncs, with the main + * usage of structure info inference. + */ + Array sinfo_args; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("op", &op); + v->Visit("args", &args); + v->Visit("attrs", &attrs); + v->Visit("sinfo_args", &sinfo_args); + v->Visit("struct_info_", &struct_info_); + v->Visit("_checked_type_", &checked_type_); + v->Visit("span", &span); + } + + bool SEqualReduce(const CallNode* other, SEqualReducer equal) const { + // skip sinfo_args check for primitive ops. + equal->MarkGraphNode(); + return equal(op, other->op) && equal(args, other->args) && equal(attrs, other->attrs) && + (IsPrimitiveOp(op) || equal(sinfo_args, other->sinfo_args)) && + equal(struct_info_, other->struct_info_); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce->MarkGraphNode(); + hash_reduce(op); + hash_reduce(args); + hash_reduce(attrs); + if (!IsPrimitiveOp(op)) { + hash_reduce(sinfo_args); + } + hash_reduce(struct_info_); + } + + static constexpr const char* _type_key = "relax.expr.Call"; + TVM_DECLARE_FINAL_OBJECT_INFO(CallNode, ExprNode); +}; + +class Call : public Expr { + public: + /*! + * \brief The constructor + * \param op The operator to be invoked. + * \param args The arguments of the call. + * \param attrs The attributes of the call node. + * \param sinfo_args The structure info arguments passed to a function. + * \param span The source span of the expression. + */ + TVM_DLL Call(Expr op, Array args, Attrs attrs = Attrs(), + Array sinfo_args = Array(), Span span = Span()); + + TVM_DEFINE_OBJECT_REF_METHODS(Call, Expr, CallNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(CallNode); +}; + +/*! + * \brief Returns \p call with the given properties. A null property denotes 'no change'. + * Returns \p call if all properties are unchanged. Otherwise, returns a copy with the new + * fields. + */ +Call WithFields(Call call, Optional opt_op = Optional(), + Optional> opt_args = Optional>(), + Optional opt_attrs = Optional(), + Optional> opt_sinfo_args = Optional>(), + Optional opt_span = Optional()); + +/*! + * \brief Condition expression + * + * Unlike traditional statement `if`s, the if evalutes + * to the result of the branch taken. + * + * x = if (true) { 1 } else { 0 }; // x is 1 + * y = if (false) { 1 } else { 0 }; // y is 0 + * + * \note This is similar to C's ternary operator. + */ +class IfNode : public ExprNode { + public: + /*! \brief The condition. */ + Expr cond; + /*! \brief The expression evaluated when condition is true. */ + Expr true_branch; + /*! \brief The expression evaluated when condition is false */ + Expr false_branch; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("cond", &cond); + v->Visit("true_branch", &true_branch); + v->Visit("false_branch", &false_branch); + v->Visit("_checked_type_", &checked_type_); + v->Visit("struct_info_", &struct_info_); + v->Visit("span", &span); + } + + bool SEqualReduce(const IfNode* other, SEqualReducer equal) const { + equal->MarkGraphNode(); + return equal(cond, other->cond) && equal(true_branch, other->true_branch) && + equal(false_branch, other->false_branch) && equal(struct_info_, other->struct_info_); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce->MarkGraphNode(); + hash_reduce(cond); + hash_reduce(true_branch); + hash_reduce(false_branch); + hash_reduce(struct_info_); + } + + static constexpr const char* _type_key = "relax.expr.If"; + TVM_DECLARE_FINAL_OBJECT_INFO(IfNode, ExprNode); +}; + +class If : public Expr { + public: + /*! + * \brief The constructor + * \param cond The condition of a if node. + * \param true_branch The fall through branch + * \param false_branch The branch for execution when condition is false. + * \param span The source span of the expression. + */ + TVM_DLL If(Expr cond, Expr true_branch, Expr false_branch, Span span = Span()); + + TVM_DEFINE_OBJECT_REF_METHODS(If, Expr, IfNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(IfNode); +}; + +/*! + * \brief Returns \p if_expr with the given properties. A null property denotes 'no change'. + * Returns \p if_expr if all properties are unchanged. Otherwise, returns a copy with the new + * fields. + */ +If WithFields(If if_expr, Optional opt_cond = Optional(), + Optional opt_true_branch = Optional(), + Optional opt_false_branch = Optional(), + Optional opt_span = Optional()); + +/*! \brief Tuple container */ +class TupleNode : public ExprNode { + public: + /*! \brief the fields of the tuple */ + tvm::Array fields; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("fields", &fields); + v->Visit("_checked_type_", &checked_type_); + v->Visit("struct_info_", &struct_info_); + v->Visit("span", &span); + } + + bool SEqualReduce(const TupleNode* other, SEqualReducer equal) const { + // struct info can be deterministically derived from fields. + return equal(fields, other->fields); + } + + void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(fields); } + + static constexpr const char* _type_key = "relax.expr.Tuple"; + TVM_DECLARE_FINAL_OBJECT_INFO(TupleNode, ExprNode); +}; + +class Tuple : public Expr { + public: + /*! + * \brief The constructor + * \param fields The fields of a tuple. + * \param span The source span of the expression. + */ + TVM_DLL explicit Tuple(tvm::Array fields, Span span = Span()); + + TVM_DEFINE_OBJECT_REF_METHODS(Tuple, Expr, TupleNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(TupleNode); +}; + +/*! + * \brief Returns \p tuple with the given properties. A null property denotes 'no change'. + * Returns \p tuple if all properties are unchanged. Otherwise, returns a copy with the new + * fields. + */ +Tuple WithFields(Tuple tuple, Optional> opt_fields = Optional>(), + Optional opt_span = Optional()); + +/*! \brief Get index-th field out of a tuple. */ +class TupleGetItemNode : public ExprNode { + public: + /*! \brief The tuple Expression */ + Expr tuple; + /*! \brief which value to get */ + int index; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("tuple_value", &tuple); + v->Visit("index", &index); + v->Visit("struct_info_", &struct_info_); + v->Visit("_checked_type_", &checked_type_); + v->Visit("span", &span); + } + + bool SEqualReduce(const TupleGetItemNode* other, SEqualReducer equal) const { + // struct info can be deterministically tuple and index. + return equal(tuple, other->tuple) && equal(index, other->index); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce(tuple); + hash_reduce(index); + } + + static constexpr const char* _type_key = "relax.expr.TupleGetItem"; + TVM_DECLARE_FINAL_OBJECT_INFO(TupleGetItemNode, ExprNode); +}; + +class TupleGetItem : public Expr { + public: + /*! + * \brief The constructor + * \param tuple The tuple to get an element from. + * \param index The index for extracting a value in the tuple. + * \param span The source span of the expression. + */ + TVM_DLL TupleGetItem(Expr tuple, int index, Span span = Span()); + + TVM_DEFINE_OBJECT_REF_METHODS(TupleGetItem, Expr, TupleGetItemNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(TupleGetItemNode); +}; + +/*! + * \brief Returns \p tuple_get_item with the given properties. A null property denotes 'no change'. + * Returns \p tuple_get_item if all properties are unchanged. Otherwise, returns a copy with the new + * fields. + */ +TupleGetItem WithFields(TupleGetItem tuple_get_item, Optional opt_tuple = Optional(), + Optional opt_index = Optional(), + Optional opt_span = Optional()); + +/*! + * \brief Base type of all (non-function) leaf Exprs. + * \sa Expr + */ +class LeafExprNode : public ExprNode { + public: + static constexpr const char* _type_key = "relax.expr.LeafExpr"; + static constexpr const uint32_t _type_child_slots = 7; + TVM_DECLARE_BASE_OBJECT_INFO(LeafExprNode, ExprNode); +}; + +/*! + * \brief Managed reference to BaseExprNode. + * \sa LeafExprNode + */ +class LeafExpr : public Expr { + public: + TVM_DEFINE_OBJECT_REF_METHODS(LeafExpr, Expr, LeafExprNode); +}; + +/*! \brief A shape expression which allows users to construct a shape containing PrimExpr. + */ +class ShapeExprNode : public LeafExprNode { + public: + /*! The values of the shape expression. */ + Array values; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("values", &values); + v->Visit("struct_info_", &struct_info_); + v->Visit("_checked_type_", &checked_type_); + v->Visit("span", &span); + } + + bool SEqualReduce(const ShapeExprNode* other, SEqualReducer equal) const { + // struct info can be deterministically derived from values. + return equal(values, other->values); + } + + void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(values); } + + static constexpr const char* _type_key = "relax.expr.ShapeExpr"; + static constexpr const bool _type_has_method_sequal_reduce = true; + static constexpr const bool _type_has_method_shash_reduce = true; + TVM_DECLARE_FINAL_OBJECT_INFO(ShapeExprNode, LeafExprNode); +}; + +class ShapeExpr : public LeafExpr { + public: + TVM_DLL explicit ShapeExpr(Array values, Span span = Span()); + TVM_DEFINE_OBJECT_REF_METHODS(ShapeExpr, LeafExpr, ShapeExprNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(ShapeExprNode); +}; + +/*! \brief The variable class for all Relax bindings. */ +class VarNode : public LeafExprNode { + public: + /*! \brief The identifier of the variable, which is used for comparing stable equality across + * transformations. */ + Id vid; + + /*! \return The name hint of the variable */ + const String& name_hint() const { return vid->name_hint; } + + void VisitAttrs(AttrVisitor* v) { + v->Visit("vid", &vid); + v->Visit("struct_info_", &struct_info_); + v->Visit("_checked_type_", &checked_type_); + v->Visit("span", &span); + } + + bool SEqualReduce(const VarNode* other, SEqualReducer equal) const { + equal->MarkGraphNode(); + return equal(vid, other->vid) && equal(struct_info_, other->struct_info_); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce(vid); + hash_reduce(struct_info_); + } + + static constexpr const char* _type_key = "relax.expr.Var"; + static constexpr const bool _type_has_method_sequal_reduce = true; + static constexpr const bool _type_has_method_shash_reduce = true; + static constexpr const uint32_t _type_child_slots = 2; + TVM_DECLARE_BASE_OBJECT_INFO(VarNode, LeafExprNode); +}; + +class Var : public LeafExpr { + public: + TVM_DLL explicit Var(String name_hint, Optional struct_info_annotation, + Span span = Span()) + : Var(Id(name_hint), struct_info_annotation, span) {} + + TVM_DLL explicit Var(Id vid, Optional struct_info_annotation, Span span = Span()); + TVM_DEFINE_OBJECT_REF_METHODS(Var, LeafExpr, VarNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(VarNode); +}; + +/*! \brief A sub-type of the variable node used to mark dataflow variables from + * normal visible "function local" bindings. + */ +class DataflowVarNode : public VarNode { + public: + void VisitAttrs(AttrVisitor* v) { + v->Visit("vid", &vid); + v->Visit("struct_info_", &struct_info_); + v->Visit("_checked_type_", &checked_type_); + v->Visit("span", &span); + } + + bool SEqualReduce(const DataflowVarNode* other, SEqualReducer equal) const { + equal->MarkGraphNode(); + return equal(vid, other->vid) && equal(struct_info_, other->struct_info_); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce(vid); + hash_reduce(struct_info_); + } + + static constexpr const char* _type_key = "relax.expr.DataflowVar"; + static constexpr const bool _type_has_method_sequal_reduce = true; + static constexpr const bool _type_has_method_shash_reduce = true; + TVM_DECLARE_FINAL_OBJECT_INFO(DataflowVarNode, VarNode); +}; + +class DataflowVar : public Var { + public: + TVM_DLL explicit DataflowVar(String name_hint, Optional struct_info_annotation, + Span span = Span()) + : DataflowVar(Id(name_hint), struct_info_annotation, span) {} + + TVM_DLL explicit DataflowVar(Id vid, Optional struct_info_annotation, + Span span = Span()); + + TVM_DEFINE_OBJECT_REF_METHODS(DataflowVar, Var, DataflowVarNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(DataflowVarNode); +}; + +/*! + * \brief Constant tensor. + * + * \note Scalar constants are represented by ndim-0 constant tensors. + */ +class ConstantNode : public LeafExprNode { + public: + /*! \brief The data of the tensor */ + runtime::NDArray data; + + /*! \return The corresponding tensor type of the data */ + TensorType tensor_type() const; + + /*! \return Whether it is scalar(ndim-0 tensor) */ + bool is_scalar() const { return data->ndim == 0; } + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("data", &data); + v->Visit("struct_info_", &struct_info_); + v->Visit("_checked_type_", &checked_type_); + v->Visit("span", &span); + } + + bool SEqualReduce(const ConstantNode* other, SEqualReducer equal) const { + // struct info can be deterministically derived from data. + return equal(data, other->data); + } + + void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(data); } + + static constexpr const char* _type_key = "relax.expr.Constant"; + TVM_DECLARE_FINAL_OBJECT_INFO(ConstantNode, LeafExprNode); +}; + +class Constant : public LeafExpr { + public: + /*! + * \brief The constructor + * \param data The data of the constant tensor. + * \param span The source span of the expression. + */ + TVM_DLL explicit Constant(runtime::NDArray data, Span span = Span()); + + TVM_DEFINE_OBJECT_REF_METHODS(Constant, LeafExpr, ConstantNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(ConstantNode); +}; + +/*! + * \brief PrimValue. + * + * Expression representing a TIR POD expression. + */ +class PrimValueNode : public LeafExprNode { + public: + /*! \brief The prim expr representing the value */ + PrimExpr value; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("value", &value); + v->Visit("struct_info_", &struct_info_); + v->Visit("_checked_type_", &checked_type_); + v->Visit("span", &span); + } + + bool SEqualReduce(const PrimValueNode* other, SEqualReducer equal) const { + // struct info can be deterministically derived from data. + return equal(value, other->value); + } + + void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(value); } + + static constexpr const char* _type_key = "relax.expr.PrimValue"; + TVM_DECLARE_FINAL_OBJECT_INFO(PrimValueNode, LeafExprNode); +}; + +/*! + * \brief Managed reference to PrimValueNode + * \sa PrimValeNode + */ +class PrimValue : public LeafExpr { + public: + /*! + * \brief The constructor + * \param value The value input. + * \param span The source span of the expression. + */ + TVM_DLL explicit PrimValue(PrimExpr value, Span span = Span()); + + /*! + * \brief Create a int64 prim value. + * \param value The input value. + * \param span The source span of the expression. + * \return The created prim value. + */ + TVM_DLL static PrimValue Int64(int64_t value, Span span = Span()); + + TVM_DEFINE_OBJECT_REF_METHODS(PrimValue, LeafExpr, PrimValueNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(PrimValueNode); +}; + +/*! + * \brief Represent a string literal constant. + */ +class StringImmNode : public LeafExprNode { + public: + /*! \brief The data value. */ + String value; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("value", &value); + v->Visit("struct_info_", &struct_info_); + v->Visit("_checked_type_", &checked_type_); + v->Visit("span", &span); + } + + bool SEqualReduce(const StringImmNode* other, SEqualReducer equal) const { + // struct info can be deterministically derived from data. + return equal(value, other->value); + } + + void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(value); } + + static constexpr const char* _type_key = "relax.expr.StringImm"; + TVM_DECLARE_FINAL_OBJECT_INFO(StringImmNode, LeafExprNode); +}; + +/*! + * \brief Managed reference to StringImm + * \sa StringImmNode + */ +class StringImm : public LeafExpr { + public: + /*! + * \brief The constructor + * \param value The value input. + * \param span The source span of the expression. + */ + TVM_DLL explicit StringImm(String value, Span span = Span()); + + TVM_DEFINE_OBJECT_REF_METHODS(StringImm, LeafExpr, StringImmNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(StringImmNode); +}; + +/*! + * \brief Represent a data type constant. + */ +class DataTypeImmNode : public LeafExprNode { + public: + /*! \brief The data value. */ + DataType value; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("value", &value); + v->Visit("struct_info_", &struct_info_); + v->Visit("_checked_type_", &checked_type_); + v->Visit("span", &span); + } + + bool SEqualReduce(const DataTypeImmNode* other, SEqualReducer equal) const { + // struct info can be deterministically derived from data. + return equal(value, other->value); + } + + void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(value); } + + static constexpr const char* _type_key = "relax.expr.DataTypeImm"; + TVM_DECLARE_FINAL_OBJECT_INFO(DataTypeImmNode, LeafExprNode); +}; + +/*! + * \brief Managed reference to DataTypeImm + * \sa DataTypeImmNode + */ +class DataTypeImm : public LeafExpr { + public: + /*! + * \brief The constructor + * \param value The value input. + * \param span The source span of the expression. + */ + TVM_DLL explicit DataTypeImm(DataType value, Span span = Span()); + + TVM_DEFINE_OBJECT_REF_METHODS(DataTypeImm, LeafExpr, DataTypeImmNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(DataTypeImmNode); +}; + +/*! \brief The base class of a variable binding in Relax. */ +class BindingNode : public Object { + public: + /*! \brief The return variable to bound to. */ + Var var; + mutable Span span; + + static constexpr const char* _type_key = "relax.expr.Binding"; + static constexpr const bool _type_has_method_sequal_reduce = true; + static constexpr const bool _type_has_method_shash_reduce = true; + TVM_DECLARE_BASE_OBJECT_INFO(BindingNode, Object); +}; + +class Binding : public ObjectRef { + protected: + Binding() = default; + + public: + explicit Binding(ObjectPtr n) : ObjectRef(n) {} + TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(Binding); + const BindingNode* operator->() const { return static_cast(data_.get()); } + const BindingNode* get() const { return operator->(); } + using ContainerType = BindingNode; +}; + +/*! + * \brief Runtime-match the value to the struct info. + * + * This operation does runtime check, populates the un-defined symbolic shape vars + * and vars in struct_info in first occurance, and insert equality assertions in + * other cases. + */ +class MatchCastNode : public BindingNode { + public: + /*! \brief The input value to match cast. */ + Expr value; + /*! \brief The struct info pattern to match to. */ + StructInfo struct_info; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("var", &var); + v->Visit("value", &value); + v->Visit("struct_info", &struct_info); + v->Visit("span", &span); + } + + bool SEqualReduce(const MatchCastNode* other, SEqualReducer equal) const { + // NOTE: pattern can contain ShapeExpr which defines the vars + return equal.DefEqual(var, other->var) && equal.DefEqual(struct_info, other->struct_info) && + equal(value, other->value); + } + + void SHashReduce(SHashReducer hash_reduce) const { + // NOTE: pattern can contain ShapeExpr which defines the vars + hash_reduce.DefHash(var); + hash_reduce.DefHash(struct_info); + hash_reduce(value); + } + + static constexpr const char* _type_key = "relax.expr.MatchCast"; + static constexpr const bool _type_has_method_sequal_reduce = true; + static constexpr const bool _type_has_method_shash_reduce = true; + TVM_DECLARE_FINAL_OBJECT_INFO(MatchCastNode, BindingNode); +}; + +/*! + * \brief Managed reference to MatchCastNode. + * \sa MatchCastNode + */ +class MatchCast : public Binding { + public: + TVM_DLL explicit MatchCast(Var var, Expr value, StructInfo struct_info, Span span = Span()); + + TVM_DEFINE_OBJECT_REF_METHODS(MatchCast, Binding, MatchCastNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(MatchCastNode); +}; + +class VarBindingNode : public BindingNode { + public: + /*! \brief The binding value. */ + Expr value; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("var", &var); + v->Visit("value", &value); + v->Visit("span", &span); + } + + bool SEqualReduce(const VarBindingNode* other, SEqualReducer equal) const { + return equal.DefEqual(var, other->var) && equal(value, other->value); + } + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce.DefHash(var); + hash_reduce(value); + } + static constexpr const char* _type_key = "relax.expr.VarBinding"; + static constexpr const bool _type_has_method_sequal_reduce = true; + static constexpr const bool _type_has_method_shash_reduce = true; + TVM_DECLARE_FINAL_OBJECT_INFO(VarBindingNode, BindingNode); +}; + +class VarBinding : public Binding { + public: + TVM_DLL explicit VarBinding(Var var, Expr value, Span span = Span()); + TVM_DEFINE_OBJECT_REF_METHODS(VarBinding, Binding, VarBindingNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(VarBindingNode); +}; + +class BindingBlockNode : public Object { + public: + mutable Span span; + Array bindings; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("span", &span); + v->Visit("bindings", &bindings); + } + + bool SEqualReduce(const BindingBlockNode* other, SEqualReducer equal) const { + return equal(bindings, other->bindings); + } + + void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(bindings); } + + static constexpr const char* _type_key = "relax.expr.BindingBlock"; + static constexpr const bool _type_has_method_sequal_reduce = true; + static constexpr const bool _type_has_method_shash_reduce = true; + TVM_DECLARE_BASE_OBJECT_INFO(BindingBlockNode, Object); +}; + +class BindingBlock : public ObjectRef { + public: + TVM_DLL explicit BindingBlock(Array bindings, Span span = Span()); + TVM_DEFINE_OBJECT_REF_METHODS(BindingBlock, ObjectRef, BindingBlockNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(BindingBlockNode); +}; + +class DataflowBlock; +class DataflowBlockNode : public BindingBlockNode { + public: + bool SEqualReduce(const DataflowBlockNode* other, SEqualReducer equal) const { + return equal(bindings, other->bindings); + } + + void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(bindings); } + + static constexpr const char* _type_key = "relax.expr.DataflowBlock"; + static constexpr const bool _type_has_method_sequal_reduce = true; + static constexpr const bool _type_has_method_shash_reduce = true; + TVM_DECLARE_FINAL_OBJECT_INFO(DataflowBlockNode, BindingBlockNode); +}; + +class DataflowBlock : public BindingBlock { + public: + TVM_DLL explicit DataflowBlock(Array bindings, Span span = Span()); + TVM_DEFINE_OBJECT_REF_METHODS(DataflowBlock, BindingBlock, DataflowBlockNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(DataflowBlockNode); +}; + +/*! \brief A sequence of blocks followed by an expression. + * + * The order of blocks enforces scoping and ordering. + */ +class SeqExprNode : public ExprNode { + public: + Array blocks; + Expr body; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("blocks", &blocks); + v->Visit("body", &body); + v->Visit("struct_info_", &struct_info_); + v->Visit("_checked_type_", &checked_type_); + v->Visit("span", &span); + } + + bool SEqualReduce(const SeqExprNode* other, SEqualReducer equal) const { + return equal(blocks, other->blocks) && equal(body, other->body) && + equal(struct_info_, other->struct_info_); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce(blocks); + hash_reduce(body); + hash_reduce(struct_info_); + } + + static constexpr const char* _type_key = "relax.expr.SeqExpr"; + static constexpr const bool _type_has_method_sequal_reduce = true; + static constexpr const bool _type_has_method_shash_reduce = true; + TVM_DECLARE_FINAL_OBJECT_INFO(SeqExprNode, ExprNode); +}; + +class SeqExpr : public Expr { + public: + TVM_DLL explicit SeqExpr(Array blocks, Expr body, Span span = Span()); + TVM_DEFINE_OBJECT_REF_METHODS(SeqExpr, Expr, SeqExprNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(SeqExprNode); +}; + +/*! \brief A Relax function. */ +class FunctionNode : public BaseFuncNode { + public: + /*! \brief The parameters to the function. */ + Array params; + /*! \brief The body of the function. */ + Expr body; + /*! \brief The return type of the function. */ + StructInfo ret_struct_info; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("params", ¶ms); + v->Visit("body", &body); + v->Visit("ret_struct_info", &ret_struct_info); + v->Visit("attrs", &attrs); + v->Visit("struct_info_", &struct_info_); + v->Visit("_checked_type_", &checked_type_); + v->Visit("span", &span); + } + + bool SEqualReduce(const FunctionNode* other, SEqualReducer equal) const { + equal->MarkGraphNode(); + return equal.DefEqual(params, other->params) && equal(body, other->body) && + equal(ret_struct_info, other->ret_struct_info) && equal(attrs, other->attrs) && + equal(struct_info_, other->struct_info_); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce->MarkGraphNode(); + hash_reduce.DefHash(params); + hash_reduce(body); + hash_reduce(ret_struct_info); + hash_reduce(attrs); + hash_reduce(struct_info_); + } + + static constexpr const char* _type_key = "relax.expr.Function"; + static constexpr const bool _type_has_method_sequal_reduce = true; + static constexpr const bool _type_has_method_shash_reduce = true; + TVM_DECLARE_FINAL_OBJECT_INFO(FunctionNode, BaseFuncNode); +}; + +class Function : public BaseFunc { + public: + TVM_DLL explicit Function(Array params, Expr body, Optional ret_struct_info, + DictAttrs attrs = NullValue(), Span span = Span()); + + /*! + * \brief Mimics the constructor but without body Expr. + * \note ret_struct_info is required, since it can not deduced by the body + */ + TVM_DLL static Function CreateEmpty(Array params, StructInfo ret_struct_info, + DictAttrs attrs = NullValue(), Span span = Span()); + + TVM_DEFINE_OBJECT_REF_METHODS(Function, BaseFunc, FunctionNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(FunctionNode); +}; + +// TODO(@sunggg): Investigate the exact usage of kComposite, kPartitionedFromPattern, and +// kPrimitive. +namespace attr { +/*! \brief Mark the function as a primitive function. */ +constexpr const char* kPrimitive = "Primitive"; +/*! + * \brief Indicate the codegen that should be used for building this function. + * When this is unset or set to "default", the default compilation pipeline will be used. + */ +constexpr const char* kCodegen = "Codegen"; +/*! \brief Treat the function as a composite operator. */ +constexpr const char* kComposite = "Composite"; +/*! \brief Indicate the function was created by the Pattern Partitioning Pass. */ +constexpr const char* kPartitionedFromPattern = "PartitionedFromPattern"; +} // namespace attr + +/*! \brief The extern function, which can represent packed function. */ +class ExternFuncNode : public BaseFuncNode { + public: + /*! \brief The name of global symbol. */ + String global_symbol; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("global_symbol", &global_symbol); + v->Visit("struct_info_", &struct_info_); + v->Visit("_checked_type_", &checked_type_); + v->Visit("span", &span); + } + + bool SEqualReduce(const ExternFuncNode* other, SEqualReducer equal) const { + return equal(global_symbol, other->global_symbol) && equal(struct_info_, other->struct_info_); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce(global_symbol); + hash_reduce(struct_info_); + } + + static constexpr const char* _type_key = "relax.expr.ExternFunc"; + static constexpr const bool _type_has_method_sequal_reduce = true; + static constexpr const bool _type_has_method_shash_reduce = true; + TVM_DECLARE_FINAL_OBJECT_INFO(ExternFuncNode, BaseFuncNode); +}; + +class ExternFunc : public BaseFunc { + public: + TVM_DLL ExternFunc(String global_symbol, Span span = Span()); + TVM_DEFINE_OBJECT_REF_METHODS(ExternFunc, BaseFunc, ExternFuncNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(ExternFuncNode); +}; + +/*! + * \brief Get the shape of Expr. + * \param expr The input expr. + * \return The corresonding shape. + * + * \note This function requires expr to be normalized. + * The function will report an error if expr's StructInfo is not TensorStructInfo. + * It will try to return symbolic function when possible. If the tensor do not + * have a compile-time symbolic shape, the function will then choose to return + * Call(relax.op.shape_of, [expr]). + */ +TVM_DLL Expr GetShapeOf(const Expr& expr); + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_EXPR_H_ diff --git a/include/tvm/relax/type.h b/include/tvm/relax/type.h new file mode 100644 index 000000000000..9c20a524353a --- /dev/null +++ b/include/tvm/relax/type.h @@ -0,0 +1,166 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/type.h + * \brief Relax Types. + */ +#ifndef TVM_RELAX_TYPE_H_ +#define TVM_RELAX_TYPE_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace tvm { +namespace relax { + +/*! \brief Indicates the number of dimensions of a tensor is unknown at compile time. */ +static constexpr int kUnknownNDim = -1; + +class ShapeTypeNode : public TypeNode { + public: + /*! \brief size of the shape. */ + int ndim; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("ndim", &ndim); + v->Visit("span", &span); + } + + bool SEqualReduce(const ShapeTypeNode* other, SEqualReducer equal) const { + return equal(ndim, other->ndim); + } + + void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(ndim); } + + static constexpr const char* _type_key = "relax.ShapeType"; + TVM_DECLARE_FINAL_OBJECT_INFO(ShapeTypeNode, TypeNode); +}; + +class ShapeType : public Type { + public: + // TODO(relax-team): remove the default value later. + TVM_DLL ShapeType(int ndim = kUnknownNDim, Span span = Span()); + + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ShapeType, Type, ShapeTypeNode); +}; + +class ObjectTypeNode : public TypeNode { + public: + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("span", &span); } + + bool SEqualReduce(const ObjectTypeNode* other, SEqualReducer equal) const { return true; } + + void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(0); } + + static constexpr const char* _type_key = "relax.ObjectType"; + TVM_DECLARE_FINAL_OBJECT_INFO(ObjectTypeNode, TypeNode); +}; + +class ObjectType : public Type { + public: + TVM_DLL ObjectType(Span span = Span()); + + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ObjectType, Type, ObjectTypeNode); +}; + +class DynTensorTypeNode : public BaseTensorTypeNode { + public: + /*! + * \brief The number of dimensions of the tensor, use -1 to denote tensor with unknwon number of + * dimensions. + */ + int ndim; + /*! \brief The content data type, use void to denote the dtype is unknown. */ + DataType dtype; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("ndim", &ndim); + v->Visit("dtype", &dtype); + v->Visit("span", &span); + } + + bool SEqualReduce(const DynTensorTypeNode* other, SEqualReducer equal) const { + return equal(ndim, other->ndim) && equal(dtype, other->dtype); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce(ndim); + hash_reduce(dtype); + } + + inline bool IsUnknownNdim() const { return ndim == kUnknownNDim; } + + inline bool IsUnknownDtype() const { return dtype.is_void(); } + + static constexpr const char* _type_key = "relax.DynTensorType"; + TVM_DECLARE_FINAL_OBJECT_INFO(DynTensorTypeNode, BaseTensorTypeNode); +}; + +/*! + * \brief Managed reference to DynTensorTypeNode. + * \sa DynTensorTypeNode. + */ +class DynTensorType : public Type { + public: + /*! + * \brief Constructor. + * \param ndim The number of dimensions of the tensor. + * \param dtype The runtime dtype of the tensor's elements. + * \param span The span. + */ + TVM_DLL DynTensorType(int ndim, DataType dtype, Span span = Span()); + + /*! + * \brief Create a DynTensorType with unknown ndim. + */ + TVM_DLL static DynTensorType CreateUnknownNDim(DataType dtype, Span span = Span()); + + TVM_DEFINE_OBJECT_REF_METHODS(DynTensorType, Type, DynTensorTypeNode); +}; + +class PackedFuncTypeNode : public TypeNode { + public: + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("span", &span); } + + bool SEqualReduce(const PackedFuncTypeNode* other, SEqualReducer equal) const { return true; } + + void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(0); } + + static constexpr const char* _type_key = "relax.PackedFuncType"; + TVM_DECLARE_FINAL_OBJECT_INFO(PackedFuncTypeNode, TypeNode); +}; + +class PackedFuncType : public Type { + public: + TVM_DLL PackedFuncType(Span span = Span()); + + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(PackedFuncType, Type, PackedFuncTypeNode); +}; + +} // namespace relax +} // namespace tvm +#endif // TVM_RELAX_TYPE_H_