diff --git a/ffi/include/tvm/ffi/container/map.h b/ffi/include/tvm/ffi/container/map.h index 7892e2008530..b1ca4f805edd 100644 --- a/ffi/include/tvm/ffi/container/map.h +++ b/ffi/include/tvm/ffi/container/map.h @@ -221,6 +221,16 @@ class MapObj : public Object { uint64_t size_; /*! \brief number of slots */ uint64_t slots_; + /*! + * \brief Small layout tag mask + * \note The most significant bit is used to indicate the small map layout. + */ + static constexpr uint64_t kSmallTagMask = static_cast(1) << 63; + /*! + * \brief Check if the map is a small map + * \return True if the map is a small map + */ + bool IsSmallMap() const { return (slots_ & kSmallTagMask) != 0ull; } /*! * \brief Optional data deleter when data is allocated separately * and its deletion is not managed by MapObj::deleter_. @@ -242,6 +252,13 @@ class SmallMapObj : public MapObj, using MapObj::iterator; using MapObj::KVType; + // Return the number of usable slots for Small layout (mask off tag). + /*! + * \brief Return the number of usable slots for Small layout (mask off tag). + * \return The number of usable slots + */ + uint64_t NumSlots() const { return slots_ & ~kSmallTagMask; } + ~SmallMapObj() { KVType* begin = static_cast(data_); for (uint64_t index = 0; index < size_; ++index) { @@ -310,6 +327,11 @@ class SmallMapObj : public MapObj, void erase(const iterator& position) { Erase(position.index); } private: + /*! + * \brief Set the number of slots and attach tags bit. + * \param n The number of slots + */ + void SetSlotsAndSmallLayoutTag(uint64_t n) { slots_ = (n & ~kSmallTagMask) | kSmallTagMask; } /*! * \brief Remove a position in SmallMapObj * \param index The position to be removed @@ -344,7 +366,7 @@ class SmallMapObj : public MapObj, ObjectPtr p = make_inplace_array_object(n); p->data_ = p->AddressOf(0); p->size_ = 0; - p->slots_ = n; + p->SetSlotsAndSmallLayoutTag(n); return p; } /*! @@ -386,15 +408,15 @@ class SmallMapObj : public MapObj, itr->second = kv.second; return; } - if (map_node->size_ < map_node->slots_) { + if (map_node->size_ < map_node->NumSlots()) { KVType* ptr = static_cast(map_node->data_) + map_node->size_; new (ptr) KVType(std::move(kv)); ++map_node->size_; return; } - uint64_t next_size = std::max(map_node->slots_ * 2, uint64_t(kInitSize)); + uint64_t next_size = std::max(map_node->NumSlots() * 2, uint64_t(kInitSize)); next_size = std::min(next_size, uint64_t(kMaxSize)); - TVM_FFI_ICHECK_GT(next_size, map_node->slots_); + TVM_FFI_ICHECK_GT(next_size, map_node->NumSlots()); ObjectPtr new_map = CreateFromRange(next_size, map_node->begin(), map_node->end()); InsertMaybeReHash(std::move(kv), &new_map); *map = std::move(new_map); @@ -525,6 +547,12 @@ class DenseMapObj : public MapObj { public: using MapObj::iterator; + /*! + * \brief Return the number of usable slots for Dense layout (MSB clear => identity). + * \return The number of usable slots + */ + uint64_t NumSlots() const { return slots_; } + /*! * \brief Destroy the DenseMapObj */ @@ -558,7 +586,7 @@ class DenseMapObj : public MapObj { */ void erase(const iterator& position) { uint64_t index = position.index; - if (position.self != nullptr && index <= this->slots_) { + if (position.self != nullptr && index <= this->NumSlots()) { Erase(ListNode(index, this)); } } @@ -817,7 +845,7 @@ class DenseMapObj : public MapObj { } /*! \brief Clear the container to empty, release all entries and memory acquired */ void Reset() { - uint64_t n_blocks = CalcNumBlocks(this->slots_); + uint64_t n_blocks = CalcNumBlocks(this->NumSlots()); for (uint64_t bi = 0; bi < n_blocks; ++bi) { uint8_t* meta_ptr = GetBlock(bi)->bytes; ItemType* data_ptr = reinterpret_cast(GetBlock(bi)->bytes + kBlockCap); @@ -852,6 +880,8 @@ class DenseMapObj : public MapObj { */ static ObjectPtr Empty(uint32_t fib_shift, uint64_t n_slots) { TVM_FFI_ICHECK_GT(n_slots, uint64_t(SmallMapObj::kMaxSize)); + // Ensure even slot count (power-of-two expected by callers; this guard + // makes the method robust if a non-even value slips through). ObjectPtr p = make_object(); uint64_t n_blocks = CalcNumBlocks(n_slots); Block* block = new Block[n_blocks]; @@ -860,7 +890,7 @@ class DenseMapObj : public MapObj { // in another shared-lib that may have different malloc/free behavior // it will still be safe. p->data_deleter_ = BlockDeleter; - p->slots_ = n_slots; + p->SetSlotsAndDenseLayoutTag(n_slots); p->size_ = 0; p->fib_shift_ = fib_shift; p->iter_list_head_ = kInvalidIndex; @@ -877,13 +907,13 @@ class DenseMapObj : public MapObj { */ static ObjectPtr CopyFrom(DenseMapObj* from) { ObjectPtr p = make_object(); - uint64_t n_blocks = CalcNumBlocks(from->slots_); + uint64_t n_blocks = CalcNumBlocks(from->NumSlots()); p->data_ = new Block[n_blocks]; // assign block deleter so even if we take re-alloc data // in another shared-lib that may have different malloc/free behavior // it will still be safe. p->data_deleter_ = BlockDeleter; - p->slots_ = from->slots_; + p->SetSlotsAndDenseLayoutTag(from->NumSlots()); p->size_ = from->size_; p->fib_shift_ = from->fib_shift_; p->iter_list_head_ = from->iter_list_head_; @@ -919,9 +949,9 @@ class DenseMapObj : public MapObj { map_node->IterListPushBack(iter); return; } - TVM_FFI_ICHECK_GT(map_node->slots_, uint64_t(SmallMapObj::kMaxSize)); + TVM_FFI_ICHECK(!map_node->IsSmallMap()); // Otherwise, start rehash - ObjectPtr p = Empty(map_node->fib_shift_ - 1, map_node->slots_ * 2); + ObjectPtr p = Empty(map_node->fib_shift_ - 1, map_node->NumSlots() * 2); // need to insert in the same order as the original map for (uint64_t index = map_node->iter_list_head_; index != kInvalidIndex;) { @@ -947,7 +977,7 @@ class DenseMapObj : public MapObj { * \brief Check whether the hash table is full * \return A boolean indicating whether hash table is full */ - bool IsFull() const { return size_ + 1 > slots_ * kMaxLoadFactor; } + bool IsFull() const { return size_ + 1 > NumSlots() * kMaxLoadFactor; } /*! * \brief Increment the pointer * \param index The pointer to be incremented @@ -1089,7 +1119,7 @@ class DenseMapObj : public MapObj { } // the probing will go to next position and round back to stay within the // correct range of the slots - index = (index + offset) % self->slots_; + index = (index + offset) % self->NumSlots(); block = self->GetBlock(index / kBlockCap); return true; } @@ -1110,7 +1140,7 @@ class DenseMapObj : public MapObj { for (uint8_t idx = 1; idx < kNumJumpDists; ++idx) { // the probing will go to next position and round back to stay within the // correct range of the slots - ListNode candidate((index + NextProbeLocation(idx)) % self->slots_, self); + ListNode candidate((index + NextProbeLocation(idx)) % self->NumSlots(), self); if (candidate.IsEmpty()) { *jump = idx; *result = candidate; @@ -1164,14 +1194,23 @@ class DenseMapObj : public MapObj { return kNextProbeLocation[index]; } friend class MapObj; + + private: + /*! + * \brief Set the number of slots and attach tags bit. + * \param n The number of slots + */ + void SetSlotsAndDenseLayoutTag(uint64_t n) { + TVM_FFI_ICHECK(((n & kSmallTagMask) == 0ull)) << "DenseMap expects MSB clear"; + slots_ = n; + } }; #define TVM_FFI_DISPATCH_MAP(base, var, body) \ { \ using TSmall = SmallMapObj*; \ using TDense = DenseMapObj*; \ - uint64_t slots = base->slots_; \ - if (slots <= SmallMapObj::kMaxSize) { \ + if (base->IsSmallMap()) { \ TSmall var = static_cast(base); \ body; \ } else { \ @@ -1184,8 +1223,7 @@ class DenseMapObj : public MapObj { { \ using TSmall = const SmallMapObj*; \ using TDense = const DenseMapObj*; \ - uint64_t slots = base->slots_; \ - if (slots <= SmallMapObj::kMaxSize) { \ + if (base->IsSmallMap()) { \ TSmall var = static_cast(base); \ body; \ } else { \ @@ -1249,7 +1287,7 @@ inline void MapObj::erase(const MapObj::iterator& position) { inline ObjectPtr MapObj::Empty() { return SmallMapObj::Empty(); } inline ObjectPtr MapObj::CopyFrom(MapObj* from) { - if (from->slots_ <= SmallMapObj::kMaxSize) { + if (from->IsSmallMap()) { return SmallMapObj::CopyFrom(static_cast(from)); } else { return DenseMapObj::CopyFrom(static_cast(from)); @@ -1288,20 +1326,22 @@ inline ObjectPtr MapObj::CreateFromRange(IterType first, IterType last) } inline void MapObj::InsertMaybeReHash(KVType&& kv, ObjectPtr* map) { - constexpr uint64_t kSmallMapMaxSize = SmallMapObj::kMaxSize; MapObj* base = static_cast(map->get()); #if TVM_FFI_DEBUG_WITH_ABI_CHANGE base->state_marker++; #endif // TVM_FFI_DEBUG_WITH_ABI_CHANGE - if (base->slots_ < kSmallMapMaxSize) { - SmallMapObj::InsertMaybeReHash(std::move(kv), map); - } else if (base->slots_ == kSmallMapMaxSize) { - if (base->size_ < base->slots_) { + if (base->IsSmallMap()) { + SmallMapObj* sm = static_cast(base); + if (sm->NumSlots() < SmallMapObj::kMaxSize) { SmallMapObj::InsertMaybeReHash(std::move(kv), map); - } else { - ObjectPtr new_map = MapObj::CreateFromRange(base->begin(), base->end()); - DenseMapObj::InsertMaybeReHash(std::move(kv), &new_map); - *map = std::move(new_map); + } else if (sm->NumSlots() == SmallMapObj::kMaxSize) { + if (base->size_ < sm->NumSlots()) { + SmallMapObj::InsertMaybeReHash(std::move(kv), map); + } else { + ObjectPtr new_map = MapObj::CreateFromRange(base->begin(), base->end()); + DenseMapObj::InsertMaybeReHash(std::move(kv), &new_map); + *map = std::move(new_map); + } } } else { DenseMapObj::InsertMaybeReHash(std::move(kv), map);