Skip to content

Commit

Permalink
[REFACTOR] StructInfo M3: MatchShape=>MatchCast (apache#323)
Browse files Browse the repository at this point in the history
* Introduce match cast, and code changes along

* add match_cast parser support (apache#9)

* Match cast support for VMShapeLower CanonicalizeBinding

* Remove `match_shape` (apache#12)

* Refactor ExprVisitor/Mutator to consider Expr in StructInfo.

Co-authored-by: Siyuan Feng <Hzfengsy@sjtu.edu.cn>
  • Loading branch information
tqchen and Hzfengsy authored Dec 29, 2022
1 parent e332285 commit 343a1e7
Show file tree
Hide file tree
Showing 52 changed files with 789 additions and 651 deletions.
10 changes: 5 additions & 5 deletions include/tvm/relax/block_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -160,13 +160,13 @@ class BlockBuilderNode : public Object {
virtual Var Emit(Expr expr, String name_hint = "") = 0;

/*!
* \brief Emit a MatchShape.
* \param value The value of the MatchShape to be emitted.
* \param pattern The pattern of the MatchShape to be emitted.
* \brief Emit a MatchCast.
* \param value The input value.
* \param struct_info The struct info to be matched.
* \param name_hint Name hint for the bound variable.
* \return The variable bound to the MatchShape.
* \return The variable bound to the MatchCast.
*/
virtual Var EmitMatchShape(Expr value, Array<PrimExpr> pattern, String name_hint = "") = 0;
virtual Var EmitMatchCast(Expr value, StructInfo struct_info, String name_hint = "") = 0;

/*!
* \brief Generate an output for the current dataflow block.
Expand Down
58 changes: 32 additions & 26 deletions include/tvm/relax/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -531,12 +531,10 @@ class Constant : public Expr {
/*! \brief The base class of a variable binding in Relax. */
class BindingNode : public Object {
public:
/*! \brief The return variable to bound to. */
Var var;
mutable Span span;

void VisitAttrs(AttrVisitor* v) { v->Visit("span", &span); }
bool SEqualReduce(const BindingNode* other, SEqualReducer equal) const { return true; }
void SHashReduce(SHashReducer hash_reduce) const {}

static constexpr const char* _type_key = "relax.expr.Binding";
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
Expand All @@ -555,51 +553,61 @@ class Binding : public ObjectRef {
using ContainerType = BindingNode;
};

/*! \brief Symbolic shape match, binds the variable of the lhs with the rhs. */
class MatchShape;
class MatchShapeNode : public BindingNode {
/*!
* \brief Runtime-match the value to the struct info.
*
* This operation does runtime check, populates the un-defined symbolic shape vars
* and vars in struct_info in first occurance, and insert equality assertions in
* other cases.
*/
class MatchCastNode : public BindingNode {
public:
/*! \brief The input value to match cast. */
Expr value;
Array<PrimExpr> pattern;
Var var;
/*! \brief The struct info pattern to match to. */
StructInfo struct_info;

void VisitAttrs(AttrVisitor* v) {
v->Visit("value", &value);
v->Visit("pattern", &pattern);
v->Visit("var", &var);
v->Visit("value", &value);
v->Visit("struct_info", &struct_info);
v->Visit("span", &span);
}

bool SEqualReduce(const MatchShapeNode* other, SEqualReducer equal) const {
bool SEqualReduce(const MatchCastNode* other, SEqualReducer equal) const {
// NOTE: pattern can contain ShapeExpr which defines the vars
return equal(value, other->value) && equal.DefEqual(pattern, other->pattern) &&
equal.DefEqual(var, other->var);
return equal.DefEqual(var, other->var) && equal.DefEqual(struct_info, other->struct_info) &&
equal(value, other->value);
}

void SHashReduce(SHashReducer hash_reduce) const {
// NOTE: pattern can contain ShapeExpr which defines the vars
hash_reduce(value);
hash_reduce.DefHash(pattern);
hash_reduce.DefHash(var);
hash_reduce.DefHash(struct_info);
hash_reduce(value);
}

static constexpr const char* _type_key = "relax.expr.MatchShape";
static constexpr const char* _type_key = "relax.expr.MatchCast";
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
TVM_DECLARE_FINAL_OBJECT_INFO(MatchShapeNode, BindingNode);
TVM_DECLARE_FINAL_OBJECT_INFO(MatchCastNode, BindingNode);
};

class MatchShape : public Binding {
/*!
* \brief Managed reference to MatchCastNode.
* \sa MatchCastNode
*/
class MatchCast : public Binding {
public:
TVM_DLL explicit MatchShape(Expr value, Array<PrimExpr> pattern, Var var, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(MatchShape, Binding, MatchShapeNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(MatchShapeNode);
TVM_DLL explicit MatchCast(Var var, Expr value, StructInfo struct_info, Span span = Span());

TVM_DEFINE_OBJECT_REF_METHODS(MatchCast, Binding, MatchCastNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(MatchCastNode);
};

class VarBinding;
class VarBindingNode : public BindingNode {
public:
Var var;
/*! \brief The binding value. */
Expr value;

void VisitAttrs(AttrVisitor* v) {
Expand Down Expand Up @@ -628,8 +636,6 @@ class VarBinding : public Binding {
TVM_DEFINE_OBJECT_REF_COW_METHOD(VarBindingNode);
};

class BindingBlock;

class BindingBlockNode : public Object {
public:
mutable Span span;
Expand Down
139 changes: 118 additions & 21 deletions include/tvm/relax/expr_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,8 @@
#include <tvm/node/functor.h>
#include <tvm/relax/block_builder.h>
#include <tvm/relax/expr.h>
#include <tvm/relay/adt.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/function.h>
#include <tvm/relax/struct_info.h>
#include <tvm/relax/struct_info_functor.h>
#include <tvm/relay/op.h>
#include <tvm/tir/function.h>

Expand Down Expand Up @@ -213,7 +212,7 @@ class ExprVisitor : public ExprFunctor<void(const Expr&)> {
virtual void VisitBinding(const Binding& binding);
// specific leaf level visitor functions
virtual void VisitBinding_(const VarBindingNode* binding);
virtual void VisitBinding_(const MatchShapeNode* binding);
virtual void VisitBinding_(const MatchCastNode* binding);
// second level dispatching based on binding value type.
// these dispatching functions get called from first-level dispatch on VarBinding
virtual void VisitBinding_(const VarBindingNode* binding, const ConstantNode* val);
Expand Down Expand Up @@ -244,6 +243,23 @@ class ExprVisitor : public ExprFunctor<void(const Expr&)> {
* \note VisitExpr_(const VarNode*) will only visit the usage site of an Var
*/
virtual void VisitVarDef(const Var& var);

/*!
* \brief Visit struct_info may recursively contain Expr/PrimExpr.
*
* By default, this function recurse into struct info such as
* TensorStructInfo and ShapeStructInfo and call VisitExpr/VisitPrimExpr
* accordingly. It does not recurse into FunctionStructInfo as it does
* not contain Expr defined in the current scope.
*
* Pass writers can overload this function to change to other behaviors.
* For example, if we are not interested in Expr in StructInfo, we can
* override this function by a no-op.
*
* \param struct_info Input struct info field.
*/
virtual void VisitExprDepStructInfoField(const StructInfo& struct_info);

// specific leaf level visitor functions
virtual void VisitVarDef_(const VarNode* var);
virtual void VisitVarDef_(const DataflowVarNode* var);
Expand All @@ -258,6 +274,30 @@ class ExprVisitor : public ExprFunctor<void(const Expr&)> {
tvm::NodeFunctor<void(const ObjectRef& n, ExprVisitor* self, const VarBindingNode* binding)>;
// initialize the vtable.
static VisitBindingVTable InitVisitBindingVTable();
/*!
* \brief Private internal struct info field visitor.
*
* Support default visiting of struct info field and recursive into
* their Expr fields.
*
* We use component instead of sub-classing so there can be other
* joint inheritance between ExprVisitor and StructInfoVisitor.
*/
class DefaultStructInfoFieldVisitor : public StructInfoVisitor {
public:
explicit DefaultStructInfoFieldVisitor(ExprVisitor* parent);

// Override defaults in struct info visitor.
void VisitStructInfoExprField(const Expr& expr) final;
void VisitStructInfoExprField(const PrimExpr& expr) final;
void VisitStructInfo_(const FuncStructInfoNode* op) final;

private:
ExprVisitor* parent_;
};
// This visitor is not visible to child classes and only
// used to supportd default visiting behavior.
DefaultStructInfoFieldVisitor default_struct_info_field_visitor_{this};
};

void PostOrderVisit(const Expr& node, std::function<void(const Expr&)> fvisit);
Expand Down Expand Up @@ -309,6 +349,64 @@ class ExprMutatorBase : public ExprFunctor<Expr(const Expr&)> {
* Can be overloaded to transform the shape expressions.
*/
virtual PrimExpr VisitPrimExpr(const PrimExpr& expr);

/*!
* \brief Visit struct_info that may recursively contain Expr/PrimExpr.
*
* By default, this function recurse into struct info such as
* TensorStructInfo and ShapeStructInfo and call VisitExpr/VisitPrimExpr
* accordingly. It does not recurse into FunctionStructInfo as it does
* not contain Expr defined in the current scope.
*
* Pass writers can overload this function to change to other behaviors.
* For example, if in Expr in StructInfo won't change, we can
* override this function by an identity function.
*
* \param struct_info Input struct info field.
* \return The updated struct info.
*/
virtual StructInfo VisitExprDepStructInfoField(const StructInfo& struct_info);

protected:
/*!
* \brief Check whether VisitExprDepStructInfoField change struct_info.
* \return Whether struct info changed.
* \note This function is used by mutator implementations to check if
* previous Expr update will trigger a change in struct_info.
* If change is detected, the implementation can generate a fresh
* node without struct_info, and trigger normalizer to re-derive.
*/
bool VisitAndCheckStructInfoFieldUnchanged(const ObjectRef& struct_info) {
if (const StructInfoNode* sinfo = struct_info.as<StructInfoNode>()) {
return this->VisitExprDepStructInfoField(GetRef<StructInfo>(sinfo)).same_as(struct_info);
} else {
return true;
}
}

private:
/*!
* \brief Private internal struct info field visitor to support
* Default visiting of struct info field and recursive into their Expr fields.
*
* We use component instead of sub-classing so there can be other
* joint inheritance between ExprMutator and StructInfoMutator.
*/
class DefaultStructInfoFieldMutator : public StructInfoMutator {
public:
explicit DefaultStructInfoFieldMutator(ExprMutatorBase* parent);

// Override defaults in struct info visitor.
Expr VisitStructInfoExprField(const Expr& expr) final;
PrimExpr VisitStructInfoExprField(const PrimExpr& expr) final;
StructInfo VisitStructInfo_(const FuncStructInfoNode* op) final;

private:
ExprMutatorBase* parent_;
};
// This visitor is not visible to child classes and only
// used to supportd default visiting behavior.
DefaultStructInfoFieldMutator default_struct_info_field_mutator_{this};
};

/*!
Expand All @@ -324,7 +422,6 @@ class ExprMutator : public ExprMutatorBase {

ExprMutator(Optional<IRModule> mod = NullOpt) { builder_ = BlockBuilder::Create(mod); }
Expr VisitExpr(const Expr& expr) override;
Expr VisitExpr_(const TupleNode* op) override;
Expr VisitExpr_(const VarNode* op) override;
Expr VisitExpr_(const DataflowVarNode* op) override;
Expr VisitExpr_(const FunctionNode* op) override;
Expand All @@ -338,7 +435,7 @@ class ExprMutator : public ExprMutatorBase {
virtual void VisitBinding(const Binding& binding);
// specific leaf level visitor functions
virtual void VisitBinding_(const VarBindingNode* binding);
virtual void VisitBinding_(const MatchShapeNode* binding);
virtual void VisitBinding_(const MatchCastNode* binding);
// second level dispatching based on binding value type.
// these dispatching functions get called from first-level dispatch on VarBinding
virtual void VisitBinding_(const VarBindingNode* binding, const ConstantNode* val);
Expand Down Expand Up @@ -484,9 +581,9 @@ class PyExprVisitorNode : public Object, public ExprVisitor {
/*! \brief The packed function to the `VisitBinding_(const VarBindingNode* binding)`
* function. */
PackedFunc f_visit_var_binding_{nullptr};
/*! \brief The packed function to the `VisitBinding_(const MatchShapeNode* binding)`
/*! \brief The packed function to the `VisitBinding_(const MatchCastNode* binding)`
* function. */
PackedFunc f_visit_match_shape_{nullptr};
PackedFunc f_visit_match_cast_{nullptr};
/*! \brief The packed function to the `VisitBindingBlock(const BindingBlock& block)`
* function. */
PackedFunc f_visit_binding_block{nullptr};
Expand Down Expand Up @@ -523,8 +620,8 @@ class PyExprVisitorNode : public Object, public ExprVisitor {
void VisitBinding_(const VarBindingNode* binding)
PY_EXPR_VISITOR_DEFAULT(GetRef<VarBinding>(binding), f_visit_var_binding_,
ExprVisitor::VisitBinding_(binding));
void VisitBinding_(const MatchShapeNode* binding)
PY_EXPR_VISITOR_DEFAULT(GetRef<MatchShape>(binding), f_visit_match_shape_,
void VisitBinding_(const MatchCastNode* binding)
PY_EXPR_VISITOR_DEFAULT(GetRef<MatchCast>(binding), f_visit_match_cast_,
ExprVisitor::VisitBinding_(binding));

void VisitBindingBlock(const BindingBlock& block)
Expand Down Expand Up @@ -602,7 +699,7 @@ class PyExprVisitor : public ObjectRef {
* \param f_visit_binding The packed function of `VisitBinding(const Binding& binding)`.
* \param f_visit_var_binding_ The packed function of `VisitBinding_(const VarBindingNode*
* binding)`.
* \param f_visit_match_shape_ The packed function of `VisitBinding_(const MatchShapeNode*
* \param f_visit_match_cast_ The packed function of `VisitBinding_(const MatchCastNode*
* binding)`.
* \param f_visit_binding_block The packed function of `VisitBindingBlock(const BindingBlock&
* block)`.
Expand All @@ -624,7 +721,7 @@ class PyExprVisitor : public ObjectRef {
PackedFunc f_visit_extern_func_, PackedFunc f_visit_global_var_, PackedFunc f_visit_function_,
PackedFunc f_visit_call_, PackedFunc f_visit_seq_expr_, PackedFunc f_visit_if_,
PackedFunc f_visit_op_, PackedFunc f_visit_tuple_getitem_, PackedFunc f_visit_binding,
PackedFunc f_visit_var_binding_, PackedFunc f_visit_match_shape_,
PackedFunc f_visit_var_binding_, PackedFunc f_visit_match_cast_,
PackedFunc f_visit_binding_block, PackedFunc f_visit_binding_block_,
PackedFunc f_visit_dataflow_block_, PackedFunc f_visit_var_def, PackedFunc f_visit_var_def_,
PackedFunc f_visit_dataflow_var_def_, PackedFunc f_visit_type, PackedFunc f_visit_span) {
Expand All @@ -649,7 +746,7 @@ class PyExprVisitor : public ObjectRef {
n->f_visit_op_ = f_visit_op_;
n->f_visit_tuple_getitem_ = f_visit_tuple_getitem_;
n->f_visit_var_binding_ = f_visit_var_binding_;
n->f_visit_match_shape_ = f_visit_match_shape_;
n->f_visit_match_cast_ = f_visit_match_cast_;
n->f_visit_binding_block_ = f_visit_binding_block_;
n->f_visit_dataflow_block_ = f_visit_dataflow_block_;
n->f_visit_var_def_ = f_visit_var_def_;
Expand Down Expand Up @@ -702,9 +799,9 @@ class PyExprMutatorNode : public Object, public ExprMutator {
/*! \brief The packed function to the `VisitBinding_(const VarBindingNode* binding)`
* function. */
PackedFunc f_visit_var_binding_{nullptr};
/*! \brief The packed function to the `VisitBinding_(const MatchShapeNode* binding)`
/*! \brief The packed function to the `VisitBinding_(const MatchCastNode* binding)`
* function. */
PackedFunc f_visit_match_shape_{nullptr};
PackedFunc f_visit_match_cast_{nullptr};
/*! \brief The packed function to the `VisitBindingBlock(const BindingBlock& block)`
* function. */
PackedFunc f_visit_binding_block{nullptr};
Expand Down Expand Up @@ -748,9 +845,9 @@ class PyExprMutatorNode : public Object, public ExprMutator {
ExprMutator::VisitBinding_(binding);
}

void VisitBinding_(const MatchShapeNode* binding) {
if (f_visit_match_shape_ != nullptr)
f_visit_match_shape_(GetRef<MatchShape>(binding));
void VisitBinding_(const MatchCastNode* binding) {
if (f_visit_match_cast_ != nullptr)
f_visit_match_cast_(GetRef<MatchCast>(binding));
else
ExprMutator::VisitBinding_(binding);
}
Expand Down Expand Up @@ -866,7 +963,7 @@ class PyExprMutator : public ObjectRef {
* \param f_visit_binding The packed function of `VisitBinding(const Binding& binding)`.
* \param f_visit_var_binding_ The packed function of `VisitBinding_(const VarBindingNode*
* binding)`.
* \param f_visit_match_shape_ The packed function of `VisitBinding_(const MatchShapeNode*
* \param f_visit_match_cast_ The packed function of `VisitBinding_(const MatchCastNode*
* binding)`.
* \param f_visit_binding_block The packed function of `VisitBindingBlock(const BindingBlock&
* block)`.
Expand All @@ -889,7 +986,7 @@ class PyExprMutator : public ObjectRef {
PackedFunc f_visit_global_var_, PackedFunc f_visit_function_, PackedFunc f_visit_call_,
PackedFunc f_visit_seq_expr_, PackedFunc f_visit_if_, PackedFunc f_visit_op_,
PackedFunc f_visit_tuple_getitem_, PackedFunc f_visit_binding,
PackedFunc f_visit_var_binding_, PackedFunc f_visit_match_shape_,
PackedFunc f_visit_var_binding_, PackedFunc f_visit_match_cast_,
PackedFunc f_visit_binding_block, PackedFunc f_visit_binding_block_,
PackedFunc f_visit_dataflow_block_, PackedFunc f_visit_var_def, PackedFunc f_visit_var_def_,
PackedFunc f_visit_dataflow_var_def_, PackedFunc f_visit_type, PackedFunc f_visit_span) {
Expand All @@ -911,7 +1008,7 @@ class PyExprMutator : public ObjectRef {
n->f_visit_tuple_getitem_ = f_visit_tuple_getitem_;
n->f_visit_binding = f_visit_binding;
n->f_visit_var_binding_ = f_visit_var_binding_;
n->f_visit_match_shape_ = f_visit_match_shape_;
n->f_visit_match_cast_ = f_visit_match_cast_;
n->f_visit_binding_block = f_visit_binding_block;
n->f_visit_binding_block_ = f_visit_binding_block_;
n->f_visit_dataflow_block_ = f_visit_dataflow_block_;
Expand Down
Loading

0 comments on commit 343a1e7

Please sign in to comment.