Skip to content

Commit 148737b

Browse files
authored
[IR] Compact Functor vtable (#17731)
This PR add a finalize routine to optionally compact functor vtable dynamically. Also updates child_slots for key types to make sure the IR node type index stay within range and such compact happens.
1 parent 6afa62c commit 148737b

File tree

14 files changed

+53
-8
lines changed

14 files changed

+53
-8
lines changed

include/tvm/arith/iter_affine_map.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ class IterMapExprNode : public PrimExprNode {
6969
void VisitAttrs(tvm::AttrVisitor* v) {}
7070

7171
static constexpr const char* _type_key = "arith.IterMapExpr";
72-
static constexpr const uint32_t _type_child_slots = 3;
72+
static constexpr const uint32_t _type_child_slots = 2;
7373
TVM_DECLARE_BASE_OBJECT_INFO(IterMapExprNode, PrimExprNode);
7474
};
7575

include/tvm/ir/expr.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ class BaseExprNode : public Object {
5858
static constexpr const char* _type_key = "BaseExpr";
5959
static constexpr const bool _type_has_method_sequal_reduce = true;
6060
static constexpr const bool _type_has_method_shash_reduce = true;
61-
static constexpr const uint32_t _type_child_slots = 62;
61+
static constexpr const uint32_t _type_child_slots = 64;
6262
TVM_DECLARE_BASE_OBJECT_INFO(BaseExprNode, Object);
6363
};
6464

@@ -104,7 +104,7 @@ class PrimExprNode : public BaseExprNode {
104104
TVM_OBJECT_ENABLE_SCRIPT_PRINTER();
105105

106106
static constexpr const char* _type_key = "PrimExpr";
107-
static constexpr const uint32_t _type_child_slots = 38;
107+
static constexpr const uint32_t _type_child_slots = 40;
108108
TVM_DECLARE_BASE_OBJECT_INFO(PrimExprNode, BaseExprNode);
109109
};
110110

include/tvm/ir/type_functor.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ class TypeFunctor<R(const Type& n, Args...)> {
9393
TVM_TYPE_FUNCTOR_DISPATCH(TupleTypeNode);
9494
TVM_TYPE_FUNCTOR_DISPATCH(PrimTypeNode);
9595
TVM_TYPE_FUNCTOR_DISPATCH(PointerTypeNode);
96+
vtable.Finalize();
9697
return vtable;
9798
}
9899
};

include/tvm/node/functor.h

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
#include <tvm/runtime/logging.h>
2727
#include <tvm/runtime/object.h>
2828

29+
#include <cstring>
2930
#include <type_traits>
3031
#include <utility>
3132
#include <vector>
@@ -72,6 +73,8 @@ class NodeFunctor<R(const ObjectRef& n, Args...)> {
7273
using TSelf = NodeFunctor<R(const ObjectRef& n, Args...)>;
7374
/*! \brief internal function table */
7475
std::vector<FPointer> func_;
76+
/*! \brief start range of func index */
77+
uint32_t begin_type_index_{0};
7578

7679
public:
7780
/*! \brief the result type of this functor */
@@ -83,6 +86,8 @@ class NodeFunctor<R(const ObjectRef& n, Args...)> {
8386
*/
8487
bool can_dispatch(const ObjectRef& n) const {
8588
uint32_t type_index = n->type_index();
89+
if (type_index < begin_type_index_) return false;
90+
type_index -= begin_type_index_;
8691
return type_index < func_.size() && func_[type_index] != nullptr;
8792
}
8893
/*!
@@ -94,7 +99,7 @@ class NodeFunctor<R(const ObjectRef& n, Args...)> {
9499
R operator()(const ObjectRef& n, Args... args) const {
95100
ICHECK(can_dispatch(n)) << "NodeFunctor calls un-registered function on type "
96101
<< n->GetTypeKey();
97-
return (*func_[n->type_index()])(n, std::forward<Args>(args)...);
102+
return (*func_[n->type_index() - begin_type_index_])(n, std::forward<Args>(args)...);
98103
}
99104
/*!
100105
* \brief set the dispatcher for type TNode
@@ -109,6 +114,7 @@ class NodeFunctor<R(const ObjectRef& n, Args...)> {
109114
func_.resize(tindex + 1, nullptr);
110115
}
111116
ICHECK(func_[tindex] == nullptr) << "Dispatch for " << TNode::_type_key << " is already set";
117+
ICHECK_EQ(begin_type_index_, 0) << " Cannot call set_dispatch after calling Finalize";
112118
func_[tindex] = f;
113119
return *this;
114120
}
@@ -122,9 +128,29 @@ class NodeFunctor<R(const ObjectRef& n, Args...)> {
122128
TSelf& clear_dispatch() { // NOLINT(*)
123129
uint32_t tindex = TNode::RuntimeTypeIndex();
124130
ICHECK_LT(tindex, func_.size()) << "clear_dispatch: index out of range";
131+
ICHECK_EQ(begin_type_index_, 0) << " Cannot call clear_dispatch after calling Finalize";
125132
func_[tindex] = nullptr;
126133
return *this;
127134
}
135+
/*!
136+
* \brief Finalize the functor after calling sequence of set_dispatch
137+
* This function will attempt to find the min type index that is not null
138+
* and optimize the space of the func table so it is more compact
139+
*/
140+
void Finalize() {
141+
ICHECK_EQ(begin_type_index_, 0) << "Can only call Finalize once";
142+
while (begin_type_index_ < func_.size() && func_[begin_type_index_] == nullptr) {
143+
++begin_type_index_;
144+
}
145+
// shift up the function value
146+
size_t new_ftable_size = func_.size() - begin_type_index_;
147+
if (begin_type_index_ != 0) {
148+
std::memmove(func_.data(), func_.data() + begin_type_index_,
149+
new_ftable_size * sizeof(FPointer));
150+
}
151+
func_.resize(new_ftable_size);
152+
func_.shrink_to_fit();
153+
}
128154
};
129155

130156
#define TVM_REG_FUNC_VAR_DEF(ClsName) static TVM_ATTRIBUTE_UNUSED auto& __make_functor##_##ClsName

include/tvm/relax/dataflow_pattern.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ TVM_DLL PatternSeq operator>>(const PatternSeq& lhs, const PatternSeq& rhs);
9191
class DFPatternNode : public Object {
9292
public:
9393
static constexpr const char* _type_key = "DFPatternNode";
94+
static constexpr const uint32_t _type_child_slots = 21;
9495
TVM_DECLARE_BASE_OBJECT_INFO(DFPatternNode, Object);
9596
};
9697

@@ -373,6 +374,7 @@ class VarPatternNode : public DFPatternNode {
373374
void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("name", &name); }
374375

375376
static constexpr const char* _type_key = "relax.dpl.VarPattern";
377+
static constexpr const uint32_t _type_child_slots = 1;
376378
TVM_DECLARE_BASE_OBJECT_INFO(VarPatternNode, DFPatternNode);
377379
};
378380

include/tvm/relax/dataflow_pattern_functor.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,12 +135,12 @@ class DFPatternFunctor<R(const DFPattern& n, Args...)> {
135135
RELAX_DFPATTERN_FUNCTOR_DISPATCH(TypePatternNode);
136136
RELAX_DFPATTERN_FUNCTOR_DISPATCH(WildcardPatternNode);
137137
RELAX_DFPATTERN_FUNCTOR_DISPATCH(VarPatternNode);
138-
139138
RELAX_DFPATTERN_FUNCTOR_DISPATCH(DataflowVarPatternNode);
140139
RELAX_DFPATTERN_FUNCTOR_DISPATCH(GlobalVarPatternNode);
141140
RELAX_DFPATTERN_FUNCTOR_DISPATCH(ExternFuncPatternNode);
142141
RELAX_DFPATTERN_FUNCTOR_DISPATCH(PrimArrPatternNode);
143142
RELAX_DFPATTERN_FUNCTOR_DISPATCH(UnorderedTuplePatternNode);
143+
vtable.Finalize();
144144
return vtable;
145145
}
146146
};

include/tvm/relax/expr.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ class StructInfoNode : public Object {
119119
static constexpr const char* _type_key = "StructInfo";
120120
static constexpr const bool _type_has_method_sequal_reduce = true;
121121
static constexpr const bool _type_has_method_shash_reduce = true;
122-
static constexpr const uint32_t _type_child_slots = 5;
122+
static constexpr const uint32_t _type_child_slots = 7;
123123
TVM_DECLARE_BASE_OBJECT_INFO(StructInfoNode, Object);
124124
};
125125

@@ -416,7 +416,7 @@ class VarNode : public LeafExprNode {
416416
static constexpr const char* _type_key = "relax.expr.Var";
417417
static constexpr const bool _type_has_method_sequal_reduce = true;
418418
static constexpr const bool _type_has_method_shash_reduce = true;
419-
static constexpr const uint32_t _type_child_slots = 2;
419+
static constexpr const uint32_t _type_child_slots = 1;
420420
TVM_DECLARE_BASE_OBJECT_INFO(VarNode, LeafExprNode);
421421
};
422422

include/tvm/relax/expr_functor.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,7 @@ class ExprFunctor<R(const Expr& n, Args...)> {
176176
RELAX_EXPR_FUNCTOR_DISPATCH(PrimValueNode);
177177
RELAX_EXPR_FUNCTOR_DISPATCH(StringImmNode);
178178
RELAX_EXPR_FUNCTOR_DISPATCH(DataTypeImmNode);
179+
vtable.Finalize();
179180
return vtable;
180181
}
181182
};

include/tvm/relax/struct_info_functor.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ class StructInfoFunctor<R(const StructInfo& n, Args...)> {
108108
TVM_STRUCT_INFO_FUNCTOR_DISPATCH(distributed::DTensorStructInfoNode);
109109
TVM_STRUCT_INFO_FUNCTOR_DISPATCH(TupleStructInfoNode);
110110
TVM_STRUCT_INFO_FUNCTOR_DISPATCH(FuncStructInfoNode);
111+
vtable.Finalize();
111112
return vtable;
112113
}
113114
};

include/tvm/tir/expr_functor.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,7 @@ class ExprFunctor<R(const PrimExpr& n, Args...)> {
193193
IR_EXPR_FUNCTOR_DISPATCH(FloatImmNode);
194194
IR_EXPR_FUNCTOR_DISPATCH(StringImmNode);
195195
IR_EXPR_FUNCTOR_DISPATCH(AnyNode);
196+
vtable.Finalize();
196197
return vtable;
197198
}
198199
};

0 commit comments

Comments
 (0)