Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 68 additions & 28 deletions ffi/include/tvm/ffi/container/map.h
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,16 @@ class MapObj : public Object {
uint64_t size_;
/*! \brief number of slots */
uint64_t slots_;
/*!
* \brief Small layout tag mask
* \note The most significant bit is used to indicate the small map layout.
*/
static constexpr uint64_t kSmallTagMask = static_cast<uint64_t>(1) << 63;
/*!
* \brief Check if the map is a small map
* \return True if the map is a small map
*/
bool IsSmallMap() const { return (slots_ & kSmallTagMask) != 0ull; }
/*!
* \brief Optional data deleter when data is allocated separately
* and its deletion is not managed by MapObj::deleter_.
Expand All @@ -242,6 +252,13 @@ class SmallMapObj : public MapObj,
using MapObj::iterator;
using MapObj::KVType;

// Return the number of usable slots for Small layout (mask off tag).
/*!
* \brief Return the number of usable slots for Small layout (mask off tag).
* \return The number of usable slots
*/
uint64_t NumSlots() const { return slots_ & ~kSmallTagMask; }

~SmallMapObj() {
KVType* begin = static_cast<KVType*>(data_);
for (uint64_t index = 0; index < size_; ++index) {
Expand Down Expand Up @@ -310,6 +327,11 @@ class SmallMapObj : public MapObj,
void erase(const iterator& position) { Erase(position.index); }

private:
/*!
* \brief Set the number of slots and attach tags bit.
* \param n The number of slots
*/
void SetSlotsAndSmallLayoutTag(uint64_t n) { slots_ = (n & ~kSmallTagMask) | kSmallTagMask; }
/*!
* \brief Remove a position in SmallMapObj
* \param index The position to be removed
Expand Down Expand Up @@ -344,7 +366,7 @@ class SmallMapObj : public MapObj,
ObjectPtr<SmallMapObj> p = make_inplace_array_object<SmallMapObj, KVType>(n);
p->data_ = p->AddressOf(0);
p->size_ = 0;
p->slots_ = n;
p->SetSlotsAndSmallLayoutTag(n);
return p;
}
/*!
Expand Down Expand Up @@ -386,15 +408,15 @@ class SmallMapObj : public MapObj,
itr->second = kv.second;
return;
}
if (map_node->size_ < map_node->slots_) {
if (map_node->size_ < map_node->NumSlots()) {
KVType* ptr = static_cast<KVType*>(map_node->data_) + map_node->size_;
new (ptr) KVType(std::move(kv));
++map_node->size_;
return;
}
uint64_t next_size = std::max(map_node->slots_ * 2, uint64_t(kInitSize));
uint64_t next_size = std::max(map_node->NumSlots() * 2, uint64_t(kInitSize));
next_size = std::min(next_size, uint64_t(kMaxSize));
TVM_FFI_ICHECK_GT(next_size, map_node->slots_);
TVM_FFI_ICHECK_GT(next_size, map_node->NumSlots());
ObjectPtr<Object> new_map = CreateFromRange(next_size, map_node->begin(), map_node->end());
InsertMaybeReHash(std::move(kv), &new_map);
*map = std::move(new_map);
Expand Down Expand Up @@ -525,6 +547,12 @@ class DenseMapObj : public MapObj {
public:
using MapObj::iterator;

/*!
* \brief Return the number of usable slots for Dense layout (MSB clear => identity).
* \return The number of usable slots
*/
uint64_t NumSlots() const { return slots_; }

/*!
* \brief Destroy the DenseMapObj
*/
Expand Down Expand Up @@ -558,7 +586,7 @@ class DenseMapObj : public MapObj {
*/
void erase(const iterator& position) {
uint64_t index = position.index;
if (position.self != nullptr && index <= this->slots_) {
if (position.self != nullptr && index <= this->NumSlots()) {
Erase(ListNode(index, this));
}
}
Expand Down Expand Up @@ -817,7 +845,7 @@ class DenseMapObj : public MapObj {
}
/*! \brief Clear the container to empty, release all entries and memory acquired */
void Reset() {
uint64_t n_blocks = CalcNumBlocks(this->slots_);
uint64_t n_blocks = CalcNumBlocks(this->NumSlots());
for (uint64_t bi = 0; bi < n_blocks; ++bi) {
uint8_t* meta_ptr = GetBlock(bi)->bytes;
ItemType* data_ptr = reinterpret_cast<ItemType*>(GetBlock(bi)->bytes + kBlockCap);
Expand Down Expand Up @@ -852,6 +880,8 @@ class DenseMapObj : public MapObj {
*/
static ObjectPtr<DenseMapObj> Empty(uint32_t fib_shift, uint64_t n_slots) {
TVM_FFI_ICHECK_GT(n_slots, uint64_t(SmallMapObj::kMaxSize));
// Ensure even slot count (power-of-two expected by callers; this guard
// makes the method robust if a non-even value slips through).
ObjectPtr<DenseMapObj> p = make_object<DenseMapObj>();
uint64_t n_blocks = CalcNumBlocks(n_slots);
Block* block = new Block[n_blocks];
Expand All @@ -860,7 +890,7 @@ class DenseMapObj : public MapObj {
// in another shared-lib that may have different malloc/free behavior
// it will still be safe.
p->data_deleter_ = BlockDeleter;
p->slots_ = n_slots;
p->SetSlotsAndDenseLayoutTag(n_slots);
p->size_ = 0;
p->fib_shift_ = fib_shift;
p->iter_list_head_ = kInvalidIndex;
Expand All @@ -877,13 +907,13 @@ class DenseMapObj : public MapObj {
*/
static ObjectPtr<DenseMapObj> CopyFrom(DenseMapObj* from) {
ObjectPtr<DenseMapObj> p = make_object<DenseMapObj>();
uint64_t n_blocks = CalcNumBlocks(from->slots_);
uint64_t n_blocks = CalcNumBlocks(from->NumSlots());
p->data_ = new Block[n_blocks];
// assign block deleter so even if we take re-alloc data
// in another shared-lib that may have different malloc/free behavior
// it will still be safe.
p->data_deleter_ = BlockDeleter;
p->slots_ = from->slots_;
p->SetSlotsAndDenseLayoutTag(from->NumSlots());
p->size_ = from->size_;
p->fib_shift_ = from->fib_shift_;
p->iter_list_head_ = from->iter_list_head_;
Expand Down Expand Up @@ -919,9 +949,9 @@ class DenseMapObj : public MapObj {
map_node->IterListPushBack(iter);
return;
}
TVM_FFI_ICHECK_GT(map_node->slots_, uint64_t(SmallMapObj::kMaxSize));
TVM_FFI_ICHECK(!map_node->IsSmallMap());
// Otherwise, start rehash
ObjectPtr<Object> p = Empty(map_node->fib_shift_ - 1, map_node->slots_ * 2);
ObjectPtr<Object> p = Empty(map_node->fib_shift_ - 1, map_node->NumSlots() * 2);

// need to insert in the same order as the original map
for (uint64_t index = map_node->iter_list_head_; index != kInvalidIndex;) {
Expand All @@ -947,7 +977,7 @@ class DenseMapObj : public MapObj {
* \brief Check whether the hash table is full
* \return A boolean indicating whether hash table is full
*/
bool IsFull() const { return size_ + 1 > slots_ * kMaxLoadFactor; }
bool IsFull() const { return size_ + 1 > NumSlots() * kMaxLoadFactor; }
/*!
* \brief Increment the pointer
* \param index The pointer to be incremented
Expand Down Expand Up @@ -1089,7 +1119,7 @@ class DenseMapObj : public MapObj {
}
// the probing will go to next position and round back to stay within the
// correct range of the slots
index = (index + offset) % self->slots_;
index = (index + offset) % self->NumSlots();
block = self->GetBlock(index / kBlockCap);
return true;
}
Expand All @@ -1110,7 +1140,7 @@ class DenseMapObj : public MapObj {
for (uint8_t idx = 1; idx < kNumJumpDists; ++idx) {
// the probing will go to next position and round back to stay within the
// correct range of the slots
ListNode candidate((index + NextProbeLocation(idx)) % self->slots_, self);
ListNode candidate((index + NextProbeLocation(idx)) % self->NumSlots(), self);
if (candidate.IsEmpty()) {
*jump = idx;
*result = candidate;
Expand Down Expand Up @@ -1164,14 +1194,23 @@ class DenseMapObj : public MapObj {
return kNextProbeLocation[index];
}
friend class MapObj;

private:
/*!
* \brief Set the number of slots and attach tags bit.
* \param n The number of slots
*/
void SetSlotsAndDenseLayoutTag(uint64_t n) {
TVM_FFI_ICHECK(((n & kSmallTagMask) == 0ull)) << "DenseMap expects MSB clear";
slots_ = n;
}
};

#define TVM_FFI_DISPATCH_MAP(base, var, body) \
{ \
using TSmall = SmallMapObj*; \
using TDense = DenseMapObj*; \
uint64_t slots = base->slots_; \
if (slots <= SmallMapObj::kMaxSize) { \
if (base->IsSmallMap()) { \
TSmall var = static_cast<TSmall>(base); \
body; \
} else { \
Expand All @@ -1184,8 +1223,7 @@ class DenseMapObj : public MapObj {
{ \
using TSmall = const SmallMapObj*; \
using TDense = const DenseMapObj*; \
uint64_t slots = base->slots_; \
if (slots <= SmallMapObj::kMaxSize) { \
if (base->IsSmallMap()) { \
TSmall var = static_cast<TSmall>(base); \
body; \
} else { \
Expand Down Expand Up @@ -1249,7 +1287,7 @@ inline void MapObj::erase(const MapObj::iterator& position) {
inline ObjectPtr<MapObj> MapObj::Empty() { return SmallMapObj::Empty(); }

inline ObjectPtr<MapObj> MapObj::CopyFrom(MapObj* from) {
if (from->slots_ <= SmallMapObj::kMaxSize) {
if (from->IsSmallMap()) {
return SmallMapObj::CopyFrom(static_cast<SmallMapObj*>(from));
} else {
return DenseMapObj::CopyFrom(static_cast<DenseMapObj*>(from));
Expand Down Expand Up @@ -1288,20 +1326,22 @@ inline ObjectPtr<Object> MapObj::CreateFromRange(IterType first, IterType last)
}

inline void MapObj::InsertMaybeReHash(KVType&& kv, ObjectPtr<Object>* map) {
constexpr uint64_t kSmallMapMaxSize = SmallMapObj::kMaxSize;
MapObj* base = static_cast<MapObj*>(map->get());
#if TVM_FFI_DEBUG_WITH_ABI_CHANGE
base->state_marker++;
#endif // TVM_FFI_DEBUG_WITH_ABI_CHANGE
if (base->slots_ < kSmallMapMaxSize) {
SmallMapObj::InsertMaybeReHash(std::move(kv), map);
} else if (base->slots_ == kSmallMapMaxSize) {
if (base->size_ < base->slots_) {
if (base->IsSmallMap()) {
SmallMapObj* sm = static_cast<SmallMapObj*>(base);
if (sm->NumSlots() < SmallMapObj::kMaxSize) {
SmallMapObj::InsertMaybeReHash(std::move(kv), map);
} else {
ObjectPtr<Object> new_map = MapObj::CreateFromRange(base->begin(), base->end());
DenseMapObj::InsertMaybeReHash(std::move(kv), &new_map);
*map = std::move(new_map);
} else if (sm->NumSlots() == SmallMapObj::kMaxSize) {
if (base->size_ < sm->NumSlots()) {
SmallMapObj::InsertMaybeReHash(std::move(kv), map);
} else {
ObjectPtr<Object> new_map = MapObj::CreateFromRange(base->begin(), base->end());
DenseMapObj::InsertMaybeReHash(std::move(kv), &new_map);
*map = std::move(new_map);
}
}
} else {
DenseMapObj::InsertMaybeReHash(std::move(kv), map);
Expand Down
Loading