From d81d1a693c6fa1d5fb1ce903b05faf9d7267158d Mon Sep 17 00:00:00 2001 From: Samana Date: Wed, 26 Nov 2014 22:33:21 -0500 Subject: [PATCH] + Added .Net style Dictionary ( cache friendly hash table ) + Minor VC2013 compliance fixes --- include/avl.h | 2 + include/dictionary.h | 562 ++++++++++++++++++++++++++++++++++++++++ include/hash_multi.h | 4 +- include/prime.h | 2 +- include/suffix_array.h | 3 +- include/suffix_tree.h | 2 + msvc/alg_vs.h | 4 +- src/dictionary_demo.cpp | 49 ++++ 8 files changed, 622 insertions(+), 6 deletions(-) create mode 100644 include/dictionary.h create mode 100644 src/dictionary_demo.cpp diff --git a/include/avl.h b/include/avl.h index b339739c..4b8f00cd 100644 --- a/include/avl.h +++ b/include/avl.h @@ -25,6 +25,8 @@ #include #include #include +#include +#include namespace alg { diff --git a/include/dictionary.h b/include/dictionary.h new file mode 100644 index 00000000..4c38a9cf --- /dev/null +++ b/include/dictionary.h @@ -0,0 +1,562 @@ +/******************************************************************************* +* DANIEL'S ALGORITHM IMPLEMENTAIONS +* +* /\ | _ _ ._ o _|_ |_ ._ _ _ +* /--\ | (_| (_) | | |_ | | | | | _> +* _| +* +* .Net Dictionary Implementation (Cache friendly hash table) +* +******************************************************************************/ + +#pragma once + +#include +#include +#include +#include "hash_code.h" +#include "prime.h" + +namespace alg +{ + +template< typename TKey, typename TValue, typename THash = hash_code > +class Dictionary +{ +public: + struct KeyValuePair + { + TKey Key; + TValue Value; + }; + +private: + struct Entry : public KeyValuePair + { + int32_t HashCode; + int32_t Next; + + Entry() + : HashCode(-1) + , Next(-1) + { + } + + void Reset() + { + HashCode = -1; + Next = -1; + Key = TKey(); + Value = TValue(); + } + }; + +private: + std::vector m_Buckets; + std::vector m_Entries; + int32_t m_Count; + int32_t m_FreeList; + int32_t m_FreeCount; + + friend class Iterator; +public: + + template + class IteratorBase + { + protected: + DictType* Dict; + int32_t Index; + EntryType* Current; + + friend class Dictionary; + + IteratorBase(DictType* dict) + : Dict(dict) + , Index(0) + , Current(nullptr) + { + } + + public: + TIter& operator++() + { + while ((uint32_t) Index < (uint32_t) Dict->m_Count) + { + if (Dict->m_Entries[Index].HashCode >= 0) + { + Current = &Dict->m_Entries[Index]; + Index++; + return *static_cast(this); + } + Index++; + } + + Index = Dict->m_Count + 1; + Current = nullptr; + return *static_cast(this); + } + + TIter operator++(int32_t) + { + TIter tmp = *static_cast(this); + ++(*this); + return tmp; + } + + bool operator == (const TIter& other) const + { + return Dict == other.Dict + && Index == other.Index + && Current == other.Current; + } + + bool operator != (const TIter& other) const + { + return !(*this == other); + } + }; + + class Iterator : public IteratorBase + { + friend class Dictionary; + private: + Iterator(Dictionary* dict) + : IteratorBase(dict) + { + } + public: + KeyValuePair& operator*() const + { + return *Current; + } + }; + + class ConstIterator : public IteratorBase + { + friend class Dictionary; + private: + ConstIterator(const Dictionary* dict) + : IteratorBase(dict) + { + } + public: + const KeyValuePair& operator*() const + { + return *Current; + } + }; + +public: + typedef Iterator iterator; + typedef ConstIterator const_iterator; + +public: + Dictionary(int32_t capacity = 0) + : m_Count(0) + , m_FreeList(-1) + , m_FreeCount(0) + { + _Init(capacity); + } + + ~Dictionary() + { + Clear(); + } + + int32_t Size() const + { + return m_Count - m_FreeCount; + } + + TValue& operator[](const TKey& key) + { + int32_t i = _FindEntry(key); + if (i >= 0) + { + return m_Entries[i].Value; + } + throw MxKeyNotFoundException(); + } + const TValue& operator[](const TKey& key) const + { + int32_t i = _FindEntry(key); + if (i >= 0) + { + return m_Entries[i].Value; + } + throw MxKeyNotFoundException(); + } + + bool TryGetValue(const TKey& key, TValue& outValue) const + { + int32_t i = _FindEntry(key); + if (i >= 0) + { + outValue = m_Entries[i].Value; + return true; + } + else + { + return false; + } + } + + TValue TryGetValueOrDefault(const TKey& key, const TValue& defaultValue) const + { + int32_t i = _FindEntry(key); + if (i >= 0) + { + return m_Entries[i].Value; + } + else + { + return defaultValue; + } + } + + const TValue& TryGetValueRefOrDefault(const TKey& key, const TValue& defaultValue) const + { + int32_t i = _FindEntry(key); + if (i >= 0) + { + return m_Entries[i].Value; + } + else + { + return defaultValue; + } + } + + TValue* TryGetValuePtr(const TKey& key) + { + int32_t i = _FindEntry(key); + if (i >= 0) + { + return &m_Entries[i].Value; + } + else + { + return nullptr; + } + } + + const TValue* TryGetValuePtr(const TKey& key) const + { + int32_t i = _FindEntry(key); + if (i >= 0) + { + return &m_Entries[i].Value; + } + else + { + return nullptr; + } + } + + void AddOrUpdate(const TKey& key, const TValue& value) + { + _Insert(key, value, false); + } + + bool ContainsKey(const TKey& key) const + { + int32_t i = _FindEntry(key); + return i >= 0; + } + + bool Contains(const std::pair& pair) const + { + int32_t i = _FindEntry(pair.first); + return i >= 0 && pair.second == m_Entries[i].Value; + } + + bool Add(const TKey& key, const TValue& value) + { + int32_t i = _FindEntry(key); + if (i >= 0) + { + return false; + } + + return _Insert(key, value, true); + } + + bool Add(const TKey& key, TValue&& value) + { + int32_t i = _FindEntry(key); + if (i >= 0) + { + return false; + } + + return _Insert(key, value, true); + } + + bool Remove(const TKey& key) + { + int32_t hashCode = THash()(key) & 0x7FFFFFFF; + int32_t bucket = hashCode % m_Buckets.size(); + int32_t last = -1; + for (int i = m_Buckets[bucket]; i >= 0; last = i, i = m_Entries[i].Next) + { + if (m_Entries[i].HashCode == hashCode && m_Entries[i].Key == key) + { + if (last < 0) + { + m_Buckets[bucket] = m_Entries[i].Next; + } + else + { + m_Entries[last].Next = m_Entries[i].Next; + } + m_Entries[i].HashCode = -1; + m_Entries[i].Next = m_FreeList; + m_Entries[i].Key = TKey(); + m_Entries[i].Value = TValue(); + m_FreeList = i; + m_FreeCount++; + + return true; + } + } + return false; + } + + void Clear() + { + if (m_Count > 0) + { + memset(m_Buckets.data(), -1, m_Buckets.size() * sizeof(m_Buckets[0])); + for (auto& entry : m_Entries) + { + entry.Reset(); + } + m_FreeList = -1; + m_FreeCount = 0; + m_Count = 0; + } + } + + Iterator Begin() + { + return ++Iterator(this); + } + + ConstIterator Begin() const + { + return CBegin(); + } + + ConstIterator CBegin() const + { + return ++ConstIterator(this); + } + + Iterator End() + { + Iterator ret(this); + ret.Index = m_Count + 1; + ret.Current = nullptr; + return ret; + } + + ConstIterator End() const + { + return CEnd(); + } + + ConstIterator CEnd() const + { + ConstIterator ret(this); + ret.Index = m_Count + 1; + ret.Current = nullptr; + return ret; + } + + //STL style + iterator begin() + { + return Begin(); + } + + const_iterator begin() const + { + return CBegin(); + } + + const_iterator cbegin() const + { + return CBegin(); + } + + iterator end() + { + return End(); + } + + const_iterator end() const + { + return CEnd(); + } + + const_iterator cend() const + { + return CEnd(); + } +private: + int32_t _FindEntry(const TKey& key) const + { + if (m_Buckets.size() > 0) + { + int32_t hashCode = THash()(key) & 0x7FFFFFFF; + for (int32_t i = m_Buckets[hashCode % m_Buckets.size()]; i >= 0; i = m_Entries[i].Next) + { + if (m_Entries[i].HashCode == hashCode && m_Entries[i].Key == key) + { + return i; + } + } + } + return -1; + } + + void _Init(int32_t capacity) + { + int32_t size = GetNextPrime(capacity); + m_Buckets.clear(); + m_Buckets.resize(size, -1); + m_Entries.clear(); + m_Entries.resize(size); + m_FreeList = -1; + } + + template + bool _Insert(const TKey& key, TValueRef value, bool add) + { + if (m_Buckets.size() == 0) + { + _Init(3); + } + + int32_t hashCode = THash()(key) & 0x7FFFFFFF; + int32_t targetBucket = hashCode % m_Buckets.size(); + + for (int32_t i = m_Buckets[targetBucket]; i >= 0; i = m_Entries[i].Next) + { + if (m_Entries[i].HashCode == hashCode && m_Entries[i].Key == key) + { + if (add) + { + return false; + } + m_Entries[i].Value = value; + return true; + } + } + + int32_t index; + if (m_FreeCount > 0) + { + index = m_FreeList; + m_FreeList = m_Entries[index].Next; + m_FreeCount--; + } + else + { + if (m_Count == m_Entries.size()) + { + _Resize(); + targetBucket = hashCode % m_Buckets.size(); + } + index = m_Count; + m_Count++; + } + + m_Entries[index].HashCode = hashCode; + m_Entries[index].Next = m_Buckets[targetBucket]; + m_Entries[index].Key = key; + m_Entries[index].Value = value; + + m_Buckets[targetBucket] = index; + + return true; + } + + void _Resize() + { + _Resize(GetNextPrime(m_Count * 2), false); + } + + void _Resize(int32_t newSize, bool forceNewHashCodes) + { + assert(newSize >= m_Entries.size()); + + m_Buckets.resize(0); + m_Buckets.resize(newSize, -1); + m_Entries.resize(newSize); + + if (forceNewHashCodes) + { + for (int32_t i = 0; i < m_Count; i++) + { + if (m_Entries[i].HashCode != -1) + { + m_Entries[i].HashCode = (THash()(m_Entries[i].Key) & 0x7FFFFFFF); + } + } + } + for (int32_t i = 0; i < m_Count; i++) + { + if (m_Entries[i].HashCode >= 0) + { + int32_t bucket = m_Entries[i].HashCode % newSize; + m_Entries[i].Next = m_Buckets[bucket]; + m_Buckets[bucket] = i; + } + } + } + + + static int GetNextPrime(int n) + { + static const int c_PrimeArraySize = 72; + static const int c_Primes[c_PrimeArraySize] = + { + 3, 7, 11, 17, 23, 29, 37, 47, 59, 71, 89, 107, 131, 163, 197, 239, 293, 353, 431, 521, 631, 761, 919, + 1103, 1327, 1597, 1931, 2333, 2801, 3371, 4049, 4861, 5839, 7013, 8419, 10103, 12143, 14591, + 17519, 21023, 25229, 30293, 36353, 43627, 52361, 62851, 75431, 90523, 108631, 130363, 156437, + 187751, 225307, 270371, 324449, 389357, 467237, 560689, 672827, 807403, 968897, 1162687, 1395263, + 1674319, 2009191, 2411033, 2893249, 3471899, 4166287, 4999559, 5999471, 7199369 + }; + static const int c_HashPrime = 101; + + if (n < 0) + { + return -1; + } + + for (int i = 0; i < c_PrimeArraySize; i++) + { + int prime = c_Primes[i]; + if (prime >= n) + { + return prime; + } + } + + //outside of our predefined table. + //compute the hard way. + for (int i = (n | 1); i < INT32_MAX; i += 2) + { + if (is_prime(i) && ((i - 1) % c_HashPrime != 0)) + { + return i; + } + } + return n; + } +}; + +} \ No newline at end of file diff --git a/include/hash_multi.h b/include/hash_multi.h index bc75cb37..4639e229 100644 --- a/include/hash_multi.h +++ b/include/hash_multi.h @@ -42,7 +42,7 @@ namespace alg { } #ifdef _MSC_VER -#define log2(x) (log(x) / log(2)) +#define log2(x) (log(x) / log(2.0)) #endif /** @@ -50,7 +50,7 @@ namespace alg { */ static MultiHash * multi_hash_init(uint32_t size) { // find prime larger than log2(size) - uint32_t r = ceil(log2(size)); + uint32_t r = ceil(log2((double)size)); int i; for (i = r; ;i++) { if (is_prime(i)) { diff --git a/include/prime.h b/include/prime.h index fd7b73ef..81f772e4 100644 --- a/include/prime.h +++ b/include/prime.h @@ -34,7 +34,7 @@ namespace alg { if (n%2 == 0) return false; - unsigned sqrtn = sqrt(n); + unsigned sqrtn = sqrt((double)n); for (unsigned int i = 3; i <= sqrtn; i+=2) { if (n % i == 0) { return false; diff --git a/include/suffix_array.h b/include/suffix_array.h index 92fd67e9..c078bd0a 100644 --- a/include/suffix_array.h +++ b/include/suffix_array.h @@ -27,6 +27,7 @@ #include #include #include +#include using namespace std; @@ -69,7 +70,7 @@ namespace alg { for(size_t k=0;k>1) + +using namespace alg; +using namespace std::chrono; + +int main(void) { + + Dictionary dict; + + dict.Add(0, 1); + dict.Add(1, 2); + dict.Add(5, 2); + dict.Add(3, 3); + dict.Remove(5); + dict.AddOrUpdate(3, 4); + + for (auto x : dict) + { + printf("%d - %d\n", x.Key, x.Value); + } + + static const uint32_t TEST_LENGTH = 1000000; + Dictionary d(TEST_LENGTH); + HashTable h(TEST_LENGTH); + + auto t0 = high_resolution_clock::now(); + + for (uint32_t i = 0; i < TEST_LENGTH; i++) + { + d.AddOrUpdate(alg::LCG(), alg::LCG()); + } + + auto t1 = high_resolution_clock::now(); + + for (uint32_t i = 0; i < TEST_LENGTH; i++) + { + h[alg::LCG()] = alg::LCG(); + } + + auto t2 = high_resolution_clock::now(); + + auto dt0 = duration_cast(t1 - t0).count(); + auto dt1 = duration_cast(t2 - t1).count(); + + printf("Dictionary: %lld ms, HashTable: %lld ms\n", dt0, dt1); +}