From 69fc6355f8369b3047c1b719b573d031b736fa88 Mon Sep 17 00:00:00 2001 From: phillip-stephens Date: Tue, 4 Jun 2024 16:16:38 -0700 Subject: [PATCH 1/7] added unit test --- src/internal/cachehash/cachehash_test.go | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/src/internal/cachehash/cachehash_test.go b/src/internal/cachehash/cachehash_test.go index b10b8013..1ec7280c 100644 --- a/src/internal/cachehash/cachehash_test.go +++ b/src/internal/cachehash/cachehash_test.go @@ -16,6 +16,7 @@ package cachehash import ( "fmt" + "github.com/stretchr/testify/assert" "testing" ) @@ -113,3 +114,21 @@ func TestEject(t *testing.T) { t.Error("Ejected element not removed from hash") } } + +func TestAddExistingBumpsToFront(t *testing.T) { + ch := new(CacheHash) + ch.Init(5) + firstValueKey1 := "value1" + secondValueKey1 := "newValue1" + ch.Add("key1", firstValueKey1) + ch.Add("key2", "value2") + ch.Add("key3", "value3") + ch.Add("key1", secondValueKey1) + k, v := ch.First() + assert.Equal(t, "key1", k, "key1 should be bumped to front since it was just added") + assert.Equal(t, secondValueKey1, v, "add existing should update value") + + k, v = ch.Last() + assert.Equal(t, "key2", k, "key2 should be last") + assert.Equal(t, "value2", v, "key2 should have value: value2") +} From d0e3ac940dcdbe3fceb6ebf1668e386d186077ed Mon Sep 17 00:00:00 2001 From: phillip-stephens Date: Tue, 4 Jun 2024 16:18:28 -0700 Subject: [PATCH 2/7] change add cache method to upsert --- src/internal/cachehash/cachehash.go | 2 +- src/internal/cachehash/cachehash_test.go | 32 +++++++++++----------- src/internal/cachehash/shardedcachehash.go | 2 +- src/modules/mxlookup/mx_lookup.go | 6 ++-- src/zdns/cache.go | 3 +- 5 files changed, 24 insertions(+), 21 deletions(-) diff --git a/src/internal/cachehash/cachehash.go b/src/internal/cachehash/cachehash.go index 6d3379b2..7978ad2f 100644 --- a/src/internal/cachehash/cachehash.go +++ b/src/internal/cachehash/cachehash.go @@ -55,7 +55,7 @@ func (c *CacheHash) Eject() { c.len-- } -func (c *CacheHash) Add(k interface{}, v interface{}) bool { +func (c *CacheHash) Upsert(k interface{}, v interface{}) bool { e, ok := c.h[k] if ok { kv := e.Value.(keyValue) diff --git a/src/internal/cachehash/cachehash_test.go b/src/internal/cachehash/cachehash_test.go index 1ec7280c..303c46cc 100644 --- a/src/internal/cachehash/cachehash_test.go +++ b/src/internal/cachehash/cachehash_test.go @@ -23,7 +23,7 @@ import ( func TestAddOne(t *testing.T) { ch := new(CacheHash) ch.Init(5) - ch.Add("key1", "value1") + ch.Upsert("key1", "value1") if ch.Len() != 1 { t.Error("unable to add any elements") } @@ -39,8 +39,8 @@ func TestAddOne(t *testing.T) { func TestFirstLastSetProperly(t *testing.T) { ch := new(CacheHash) ch.Init(5) - ch.Add("key1", "value1") - ch.Add("key2", "value2") + ch.Upsert("key1", "value1") + ch.Upsert("key2", "value2") if ch.Len() != 2 { t.Error("unable to add multiple elements") } @@ -55,9 +55,9 @@ func TestFirstLastSetProperly(t *testing.T) { func TestDelete(t *testing.T) { ch := new(CacheHash) ch.Init(5) - ch.Add("key1", "value1") - ch.Add("key2", "value2") - ch.Add("key3", "value3") + ch.Upsert("key1", "value1") + ch.Upsert("key2", "value2") + ch.Upsert("key3", "value3") if ch.Len() != 3 { t.Error("unable to add multiple elements") } @@ -84,8 +84,8 @@ func TestDelete(t *testing.T) { func TestMoveFront(t *testing.T) { ch := new(CacheHash) ch.Init(5) - ch.Add("key1", "value1") - ch.Add("key2", "value2") + ch.Upsert("key1", "value1") + ch.Upsert("key2", "value2") ch.Get("key1") if k, v := ch.First(); k != "key1" || v != "value1" { t.Error("first and last not set on add") @@ -98,9 +98,9 @@ func TestMoveFront(t *testing.T) { func TestEject(t *testing.T) { ch := new(CacheHash) ch.Init(2) - ch.Add("key1", "value1") - ch.Add("key2", "value2") - ch.Add("key3", "value3") + ch.Upsert("key1", "value1") + ch.Upsert("key2", "value2") + ch.Upsert("key3", "value3") if ch.Len() != 2 { t.Error("length not respected") } @@ -115,15 +115,15 @@ func TestEject(t *testing.T) { } } -func TestAddExistingBumpsToFront(t *testing.T) { +func TestUpsertExistingBumpsToFront(t *testing.T) { ch := new(CacheHash) ch.Init(5) firstValueKey1 := "value1" secondValueKey1 := "newValue1" - ch.Add("key1", firstValueKey1) - ch.Add("key2", "value2") - ch.Add("key3", "value3") - ch.Add("key1", secondValueKey1) + ch.Upsert("key1", firstValueKey1) + ch.Upsert("key2", "value2") + ch.Upsert("key3", "value3") + ch.Upsert("key1", secondValueKey1) k, v := ch.First() assert.Equal(t, "key1", k, "key1 should be bumped to front since it was just added") assert.Equal(t, secondValueKey1, v, "add existing should update value") diff --git a/src/internal/cachehash/shardedcachehash.go b/src/internal/cachehash/shardedcachehash.go index 66b7653b..cd19cab2 100644 --- a/src/internal/cachehash/shardedcachehash.go +++ b/src/internal/cachehash/shardedcachehash.go @@ -43,7 +43,7 @@ func (c *ShardedCacheHash) getShard(k interface{}) *CacheHash { } func (c *ShardedCacheHash) Add(k interface{}, v interface{}) bool { - return c.getShard(k).Add(k, v) + return c.getShard(k).Upsert(k, v) } func (c *ShardedCacheHash) Get(k interface{}) (interface{}, bool) { diff --git a/src/modules/mxlookup/mx_lookup.go b/src/modules/mxlookup/mx_lookup.go index b391d38b..e07d0062 100644 --- a/src/modules/mxlookup/mx_lookup.go +++ b/src/modules/mxlookup/mx_lookup.go @@ -17,11 +17,13 @@ import ( "github.com/pkg/errors" "github.com/spf13/pflag" "github.com/zmap/dns" - "github.com/zmap/zdns/src/cli" "strings" "sync" + "github.com/zmap/zdns/src/cli" + log "github.com/sirupsen/logrus" + "github.com/zmap/zdns/src/internal/cachehash" "github.com/zmap/zdns/src/zdns" ) @@ -115,7 +117,7 @@ func (mxMod *MXLookupModule) lookupIPs(r *zdns.Resolver, name, nameServer string retv.IPv6Addresses = result.IPv6Addresses } mxMod.CHmu.Lock() - mxMod.CacheHash.Add(name, retv) + mxMod.CacheHash.Upsert(name, retv) mxMod.CHmu.Unlock() return retv, trace } diff --git a/src/zdns/cache.go b/src/zdns/cache.go index b7d19ede..326c71a2 100644 --- a/src/zdns/cache.go +++ b/src/zdns/cache.go @@ -19,6 +19,7 @@ import ( log "github.com/sirupsen/logrus" "github.com/zmap/dns" + "github.com/zmap/zdns/src/internal/cachehash" ) @@ -82,7 +83,7 @@ func (s *Cache) AddCachedAnswer(answer interface{}, depth int) { ExpiresAt: expiresAt} ca.Answers[a] = ta s.IterativeCache.Add(q, ca) - s.VerboseLog(depth+1, "Add cached answer ", q, " ", ca) + s.VerboseLog(depth+1, "Upsert cached answer ", q, " ", ca) } func (s *Cache) GetCachedResult(q Question, isAuthCheck bool, depth int) (SingleQueryResult, bool) { From 84ea3eaa1458f2db21e453a7bc693ef7c3630aae Mon Sep 17 00:00:00 2001 From: phillip-stephens Date: Tue, 4 Jun 2024 16:18:57 -0700 Subject: [PATCH 3/7] simplify if statement --- src/internal/cachehash/cachehash.go | 26 ++++++++++++-------------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/src/internal/cachehash/cachehash.go b/src/internal/cachehash/cachehash.go index 7978ad2f..769ab188 100644 --- a/src/internal/cachehash/cachehash.go +++ b/src/internal/cachehash/cachehash.go @@ -58,22 +58,20 @@ func (c *CacheHash) Eject() { func (c *CacheHash) Upsert(k interface{}, v interface{}) bool { e, ok := c.h[k] if ok { - kv := e.Value.(keyValue) - kv.Key = k - kv.Value = v + e.Value = v c.l.MoveToFront(e) - } else { - if c.len >= c.maxLen { - c.Eject() - } - var kv keyValue - kv.Key = k - kv.Value = v - e = c.l.PushFront(kv) - c.len++ - c.h[k] = e + return true } - return ok + if c.len >= c.maxLen { + c.Eject() + } + var kv keyValue + kv.Key = k + kv.Value = v + e = c.l.PushFront(kv) + c.len++ + c.h[k] = e + return false } func (c *CacheHash) First() (interface{}, interface{}) { From ed17de5355bdacd5e32418269db295d34ca459b6 Mon Sep 17 00:00:00 2001 From: phillip-stephens Date: Tue, 4 Jun 2024 16:19:49 -0700 Subject: [PATCH 4/7] added func comment --- src/internal/cachehash/cachehash.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/internal/cachehash/cachehash.go b/src/internal/cachehash/cachehash.go index 769ab188..5f5547eb 100644 --- a/src/internal/cachehash/cachehash.go +++ b/src/internal/cachehash/cachehash.go @@ -55,6 +55,11 @@ func (c *CacheHash) Eject() { c.len-- } +// Upsert inserts a new key-value pair into the cache. +// If the key already exists, the value is updated and the key is moved to the front of the list. +// If the key does not exist in the cache, the key-value pair is added to the front of the list. +// If the cache is full, the oldest key-value pair is removed. +// Returns whether the key already existed in the cache. func (c *CacheHash) Upsert(k interface{}, v interface{}) bool { e, ok := c.h[k] if ok { From 3b95884f8bec1c4f1b2f89b4307425d447d0d422 Mon Sep 17 00:00:00 2001 From: phillip-stephens Date: Tue, 4 Jun 2024 16:50:10 -0700 Subject: [PATCH 5/7] fixed tests and added comments --- src/internal/cachehash/cachehash.go | 50 ++++++++++++++++-------- src/internal/cachehash/cachehash_test.go | 8 ++-- 2 files changed, 37 insertions(+), 21 deletions(-) diff --git a/src/internal/cachehash/cachehash.go b/src/internal/cachehash/cachehash.go index 5f5547eb..86b4bd39 100644 --- a/src/internal/cachehash/cachehash.go +++ b/src/internal/cachehash/cachehash.go @@ -19,6 +19,9 @@ import ( "sync" ) +// CacheHash is an LRU cache implemented with a hash map and a doubly linked list. +// This allows for O(1) insertions, deletions, and lookups and ensures the most recently accesssed elements are +// persisted in the cache. type CacheHash struct { sync.Mutex h map[interface{}]*list.Element @@ -33,6 +36,7 @@ type keyValue struct { Value interface{} } +// Init initializes the cache with a maximum length. func (c *CacheHash) Init(maxLen int) { c.l = list.New() c.l = c.l.Init() @@ -41,6 +45,7 @@ func (c *CacheHash) Init(maxLen int) { c.maxLen = maxLen } +// Eject removes the oldest key-value pair from the cache. func (c *CacheHash) Eject() { if c.len == 0 { return @@ -55,31 +60,34 @@ func (c *CacheHash) Eject() { c.len-- } -// Upsert inserts a new key-value pair into the cache. +// Upsert upserts a new key-value pair into the cache. // If the key already exists, the value is updated and the key is moved to the front of the list. // If the key does not exist in the cache, the key-value pair is added to the front of the list. -// If the cache is full, the oldest key-value pair is removed. // Returns whether the key already existed in the cache. func (c *CacheHash) Upsert(k interface{}, v interface{}) bool { e, ok := c.h[k] + var updatedKV keyValue + updatedKV.Key = k + updatedKV.Value = v if ok { - e.Value = v + // update value to have the new value + e.Value = updatedKV c.l.MoveToFront(e) return true } if c.len >= c.maxLen { + // cache is full, remove oldest key-value pair c.Eject() } - var kv keyValue - kv.Key = k - kv.Value = v - e = c.l.PushFront(kv) + e = c.l.PushFront(updatedKV) c.len++ c.h[k] = e return false } -func (c *CacheHash) First() (interface{}, interface{}) { +// First returns the key-value pair at the front of the list. +// Returns nil, nil if the cache is empty. +func (c *CacheHash) First() (k interface{}, v interface{}) { if c.len == 0 { return nil, nil } @@ -88,7 +96,9 @@ func (c *CacheHash) First() (interface{}, interface{}) { return kv.Key, kv.Value } -func (c *CacheHash) Last() (interface{}, interface{}) { +// Last returns the key-value pair at the back of the list. +// Returns nil, nil if the cache is empty. +func (c *CacheHash) Last() (k interface{}, v interface{}) { if c.len == 0 { return nil, nil } @@ -97,30 +107,36 @@ func (c *CacheHash) Last() (interface{}, interface{}) { return kv.Key, kv.Value } -func (c *CacheHash) Get(k interface{}) (interface{}, bool) { +// Get returns the value associated with the key and whether the key was found in the cache. +// v is nil if the key was not found. +func (c *CacheHash) Get(k interface{}) (v interface{}, found bool) { e, ok := c.h[k] if ok { c.l.MoveToFront(e) kv := e.Value.(keyValue) - return kv.Value, ok + return kv.Value, true } - return nil, ok + return nil, false } -func (c *CacheHash) GetNoMove(k interface{}) (interface{}, bool) { +// GetNoMove returns the value associated with the key and whether the key was found in the cache. +func (c *CacheHash) GetNoMove(k interface{}) (v interface{}, found bool) { e, ok := c.h[k] if ok { - return e.Value.(keyValue).Value, ok + return e.Value.(keyValue).Value, true } - return nil, ok + return nil, false } +// Has returns whether the key is in the cache. func (c *CacheHash) Has(k interface{}) bool { _, ok := c.h[k] return ok } -func (c *CacheHash) Delete(k interface{}) (interface{}, bool) { +// Delete removes the key-value pair from the cache and returns the value and whether the key was found. +// v is nil if the key was not found. +func (c *CacheHash) Delete(k interface{}) (v interface{}, found bool) { e, ok := c.h[k] if ok != true { return nil, false @@ -132,10 +148,12 @@ func (c *CacheHash) Delete(k interface{}) (interface{}, bool) { return kv.Value, true } +// Len returns the number of key-value pairs in the cache. func (c *CacheHash) Len() int { return c.len } +// RegisterCB registers a callback function to be called when an element is ejected from the cache. func (c *CacheHash) RegisterCB(newCB func(interface{}, interface{})) { c.ejectCB = newCB } diff --git a/src/internal/cachehash/cachehash_test.go b/src/internal/cachehash/cachehash_test.go index 303c46cc..3e3e625b 100644 --- a/src/internal/cachehash/cachehash_test.go +++ b/src/internal/cachehash/cachehash_test.go @@ -118,15 +118,13 @@ func TestEject(t *testing.T) { func TestUpsertExistingBumpsToFront(t *testing.T) { ch := new(CacheHash) ch.Init(5) - firstValueKey1 := "value1" - secondValueKey1 := "newValue1" - ch.Upsert("key1", firstValueKey1) + ch.Upsert("key1", "value1") ch.Upsert("key2", "value2") ch.Upsert("key3", "value3") - ch.Upsert("key1", secondValueKey1) + ch.Upsert("key1", "newValue1") k, v := ch.First() assert.Equal(t, "key1", k, "key1 should be bumped to front since it was just added") - assert.Equal(t, secondValueKey1, v, "add existing should update value") + assert.Equal(t, "newValue1", v, "add existing should update value") k, v = ch.Last() assert.Equal(t, "key2", k, "key2 should be last") From b30974dedaef02b116f944c20161bf131f8a2213 Mon Sep 17 00:00:00 2001 From: phillip-stephens Date: Wed, 5 Jun 2024 14:39:21 -0700 Subject: [PATCH 6/7] add another unit test --- src/internal/cachehash/cachehash_test.go | 47 ++++++++++++++++++++++++ 1 file changed, 47 insertions(+) diff --git a/src/internal/cachehash/cachehash_test.go b/src/internal/cachehash/cachehash_test.go index 3e3e625b..294ced04 100644 --- a/src/internal/cachehash/cachehash_test.go +++ b/src/internal/cachehash/cachehash_test.go @@ -130,3 +130,50 @@ func TestUpsertExistingBumpsToFront(t *testing.T) { assert.Equal(t, "key2", k, "key2 should be last") assert.Equal(t, "value2", v, "key2 should have value: value2") } + +func TestUpsertWithFullCache(t *testing.T) { + ch := new(CacheHash) + ch.Init(2) + ch.Upsert("key1", "value1") + ch.Upsert("key2", "value2") + + k, v := ch.First() + assert.Equal(t, "key2", k, "First key should be key2") + assert.Equal(t, "value2", v, "First value should be value2") + + k, v = ch.Last() + assert.Equal(t, "key1", k, "Last key should be key1") + assert.Equal(t, "value1", v, "Last value should be value1") + + ch.Upsert("key3", "value3") + + assert.Len(t, ch.h, 2, "Cache should have 2 elements, since it is full and one was evicted") + + k, v = ch.First() + assert.Equal(t, "key3", k, "First key should be key3") + assert.Equal(t, "value3", v, "First value should be value3") + + // key1 should have been evicted since it was the oldest, and key2 should still be in the cache + k, v = ch.Last() + assert.Equal(t, "key2", k, "Last key should be key2") + assert.Equal(t, "value2", v, "Last value should be value2") +} + +func TestGetNoMove(t *testing.T) { + ch := new(CacheHash) + ch.Init(5) + ch.Upsert("key1", "value1") + ch.Upsert("key2", "value2") + ch.GetNoMove("key1") + + k, v := ch.First() + assert.Equal(t, "key2", k, "First key should be key2") + assert.Equal(t, "value2", v, "First value should be value2") + + v, found := ch.GetNoMove("key1") + assert.True(t, found, "key1 should be found") + + k, v = ch.First() + assert.Equal(t, "key2", k, "First key should still be key2 post GetNoMove") + assert.Equal(t, "value2", v, "First value should be value2") +} From 1b4d2d184c0babb27cafef432d0bb7c0d25584e0 Mon Sep 17 00:00:00 2001 From: phillip-stephens Date: Wed, 5 Jun 2024 14:44:46 -0700 Subject: [PATCH 7/7] added more details to comments --- src/internal/cachehash/cachehash.go | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/internal/cachehash/cachehash.go b/src/internal/cachehash/cachehash.go index 86b4bd39..6886259f 100644 --- a/src/internal/cachehash/cachehash.go +++ b/src/internal/cachehash/cachehash.go @@ -19,7 +19,8 @@ import ( "sync" ) -// CacheHash is an LRU cache implemented with a hash map and a doubly linked list. +// CacheHash is an LRU cache implemented with a hash map and a doubly linked list. The list stores key-value pairs +// in the order they were accessed, with the most recently accessed key-value pair at the front of the list. // This allows for O(1) insertions, deletions, and lookups and ensures the most recently accesssed elements are // persisted in the cache. type CacheHash struct { @@ -45,7 +46,7 @@ func (c *CacheHash) Init(maxLen int) { c.maxLen = maxLen } -// Eject removes the oldest key-value pair from the cache. +// Eject removes the least-recently used key-value pair from the cache. func (c *CacheHash) Eject() { if c.len == 0 { return @@ -108,6 +109,7 @@ func (c *CacheHash) Last() (k interface{}, v interface{}) { } // Get returns the value associated with the key and whether the key was found in the cache. +// It also moves it to the front of the list. // v is nil if the key was not found. func (c *CacheHash) Get(k interface{}) (v interface{}, found bool) { e, ok := c.h[k]