2626#include < tvm/runtime/logging.h>
2727#include < tvm/runtime/object.h>
2828
29+ #include < cstring>
2930#include < type_traits>
3031#include < utility>
3132#include < vector>
@@ -72,6 +73,8 @@ class NodeFunctor<R(const ObjectRef& n, Args...)> {
7273 using TSelf = NodeFunctor<R(const ObjectRef& n, Args...)>;
7374 /* ! \brief internal function table */
7475 std::vector<FPointer> func_;
76+ /* ! \brief start range of func index */
77+ uint32_t begin_type_index_{0 };
7578
7679 public:
7780 /* ! \brief the result type of this functor */
@@ -83,6 +86,8 @@ class NodeFunctor<R(const ObjectRef& n, Args...)> {
8386 */
8487 bool can_dispatch (const ObjectRef& n) const {
8588 uint32_t type_index = n->type_index ();
89+ if (type_index < begin_type_index_) return false ;
90+ type_index -= begin_type_index_;
8691 return type_index < func_.size () && func_[type_index] != nullptr ;
8792 }
8893 /* !
@@ -94,7 +99,7 @@ class NodeFunctor<R(const ObjectRef& n, Args...)> {
9499 R operator ()(const ObjectRef& n, Args... args) const {
95100 ICHECK (can_dispatch (n)) << " NodeFunctor calls un-registered function on type "
96101 << n->GetTypeKey ();
97- return (*func_[n->type_index ()])(n, std::forward<Args>(args)...);
102+ return (*func_[n->type_index () - begin_type_index_ ])(n, std::forward<Args>(args)...);
98103 }
99104 /* !
100105 * \brief set the dispatcher for type TNode
@@ -109,6 +114,7 @@ class NodeFunctor<R(const ObjectRef& n, Args...)> {
109114 func_.resize (tindex + 1 , nullptr );
110115 }
111116 ICHECK (func_[tindex] == nullptr ) << " Dispatch for " << TNode::_type_key << " is already set" ;
117+ ICHECK_EQ (begin_type_index_, 0 ) << " Cannot call set_dispatch after calling Finalize" ;
112118 func_[tindex] = f;
113119 return *this ;
114120 }
@@ -122,9 +128,29 @@ class NodeFunctor<R(const ObjectRef& n, Args...)> {
122128 TSelf& clear_dispatch () { // NOLINT(*)
123129 uint32_t tindex = TNode::RuntimeTypeIndex ();
124130 ICHECK_LT (tindex, func_.size ()) << " clear_dispatch: index out of range" ;
131+ ICHECK_EQ (begin_type_index_, 0 ) << " Cannot call clear_dispatch after calling Finalize" ;
125132 func_[tindex] = nullptr ;
126133 return *this ;
127134 }
135+ /* !
136+ * \brief Finalize the functor after calling sequence of set_dispatch
137+ * This function will attempt to find the min type index that is not null
138+ * and optimize the space of the func table so it is more compact
139+ */
140+ void Finalize () {
141+ ICHECK_EQ (begin_type_index_, 0 ) << " Can only call Finalize once" ;
142+ while (begin_type_index_ < func_.size () && func_[begin_type_index_] == nullptr ) {
143+ ++begin_type_index_;
144+ }
145+ // shift up the function value
146+ size_t new_ftable_size = func_.size () - begin_type_index_;
147+ if (begin_type_index_ != 0 ) {
148+ std::memmove (func_.data (), func_.data () + begin_type_index_,
149+ new_ftable_size * sizeof (FPointer));
150+ }
151+ func_.resize (new_ftable_size);
152+ func_.shrink_to_fit ();
153+ }
128154};
129155
130156#define TVM_REG_FUNC_VAR_DEF (ClsName ) static TVM_ATTRIBUTE_UNUSED auto & __make_functor##_##ClsName
0 commit comments