Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix bug in cachehash that wouldn't update a key-value pair if the key was already in the cache #376

Merged
merged 8 commits into from
Jun 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 47 additions & 24 deletions src/internal/cachehash/cachehash.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ import (
"sync"
)

// CacheHash is an LRU cache implemented with a hash map and a doubly linked list. The list stores key-value pairs
// in the order they were accessed, with the most recently accessed key-value pair at the front of the list.
// This allows for O(1) insertions, deletions, and lookups and ensures the most recently accesssed elements are
// persisted in the cache.
type CacheHash struct {
sync.Mutex
h map[interface{}]*list.Element
Expand All @@ -33,6 +37,7 @@ type keyValue struct {
Value interface{}
}

// Init initializes the cache with a maximum length.
func (c *CacheHash) Init(maxLen int) {
c.l = list.New()
c.l = c.l.Init()
Expand All @@ -41,6 +46,7 @@ func (c *CacheHash) Init(maxLen int) {
c.maxLen = maxLen
}

// Eject removes the least-recently used key-value pair from the cache.
func (c *CacheHash) Eject() {
if c.len == 0 {
return
Expand All @@ -55,28 +61,34 @@ func (c *CacheHash) Eject() {
c.len--
}

func (c *CacheHash) Add(k interface{}, v interface{}) bool {
// Upsert upserts a new key-value pair into the cache.
// If the key already exists, the value is updated and the key is moved to the front of the list.
// If the key does not exist in the cache, the key-value pair is added to the front of the list.
// Returns whether the key already existed in the cache.
func (c *CacheHash) Upsert(k interface{}, v interface{}) bool {
e, ok := c.h[k]
var updatedKV keyValue
updatedKV.Key = k
updatedKV.Value = v
if ok {
kv := e.Value.(keyValue)
kv.Key = k
kv.Value = v
// update value to have the new value
e.Value = updatedKV
c.l.MoveToFront(e)
} else {
if c.len >= c.maxLen {
c.Eject()
}
var kv keyValue
kv.Key = k
kv.Value = v
e = c.l.PushFront(kv)
c.len++
c.h[k] = e
return true
}
return ok
if c.len >= c.maxLen {
// cache is full, remove oldest key-value pair
c.Eject()
}
e = c.l.PushFront(updatedKV)
c.len++
c.h[k] = e
return false
}

func (c *CacheHash) First() (interface{}, interface{}) {
// First returns the key-value pair at the front of the list.
// Returns nil, nil if the cache is empty.
func (c *CacheHash) First() (k interface{}, v interface{}) {
if c.len == 0 {
return nil, nil
}
Expand All @@ -85,7 +97,9 @@ func (c *CacheHash) First() (interface{}, interface{}) {
return kv.Key, kv.Value
}

func (c *CacheHash) Last() (interface{}, interface{}) {
// Last returns the key-value pair at the back of the list.
// Returns nil, nil if the cache is empty.
func (c *CacheHash) Last() (k interface{}, v interface{}) {
if c.len == 0 {
return nil, nil
}
Expand All @@ -94,30 +108,37 @@ func (c *CacheHash) Last() (interface{}, interface{}) {
return kv.Key, kv.Value
}

func (c *CacheHash) Get(k interface{}) (interface{}, bool) {
// Get returns the value associated with the key and whether the key was found in the cache.
// It also moves it to the front of the list.
// v is nil if the key was not found.
func (c *CacheHash) Get(k interface{}) (v interface{}, found bool) {
e, ok := c.h[k]
if ok {
c.l.MoveToFront(e)
kv := e.Value.(keyValue)
return kv.Value, ok
return kv.Value, true
}
return nil, ok
return nil, false
}

func (c *CacheHash) GetNoMove(k interface{}) (interface{}, bool) {
// GetNoMove returns the value associated with the key and whether the key was found in the cache.
func (c *CacheHash) GetNoMove(k interface{}) (v interface{}, found bool) {
e, ok := c.h[k]
if ok {
return e.Value.(keyValue).Value, ok
return e.Value.(keyValue).Value, true
}
return nil, ok
return nil, false
}

// Has returns whether the key is in the cache.
func (c *CacheHash) Has(k interface{}) bool {
_, ok := c.h[k]
return ok
}

func (c *CacheHash) Delete(k interface{}) (interface{}, bool) {
// Delete removes the key-value pair from the cache and returns the value and whether the key was found.
// v is nil if the key was not found.
func (c *CacheHash) Delete(k interface{}) (v interface{}, found bool) {
e, ok := c.h[k]
if ok != true {
return nil, false
Expand All @@ -129,10 +150,12 @@ func (c *CacheHash) Delete(k interface{}) (interface{}, bool) {
return kv.Value, true
}

// Len returns the number of key-value pairs in the cache.
func (c *CacheHash) Len() int {
return c.len
}

// RegisterCB registers a callback function to be called when an element is ejected from the cache.
func (c *CacheHash) RegisterCB(newCB func(interface{}, interface{})) {
c.ejectCB = newCB
}
86 changes: 75 additions & 11 deletions src/internal/cachehash/cachehash_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,14 @@ package cachehash

import (
"fmt"
"github.com/stretchr/testify/assert"
"testing"
)

func TestAddOne(t *testing.T) {
ch := new(CacheHash)
ch.Init(5)
ch.Add("key1", "value1")
ch.Upsert("key1", "value1")
if ch.Len() != 1 {
t.Error("unable to add any elements")
}
Expand All @@ -38,8 +39,8 @@ func TestAddOne(t *testing.T) {
func TestFirstLastSetProperly(t *testing.T) {
ch := new(CacheHash)
ch.Init(5)
ch.Add("key1", "value1")
ch.Add("key2", "value2")
ch.Upsert("key1", "value1")
ch.Upsert("key2", "value2")
if ch.Len() != 2 {
t.Error("unable to add multiple elements")
}
Expand All @@ -54,9 +55,9 @@ func TestFirstLastSetProperly(t *testing.T) {
func TestDelete(t *testing.T) {
ch := new(CacheHash)
ch.Init(5)
ch.Add("key1", "value1")
ch.Add("key2", "value2")
ch.Add("key3", "value3")
ch.Upsert("key1", "value1")
ch.Upsert("key2", "value2")
ch.Upsert("key3", "value3")
if ch.Len() != 3 {
t.Error("unable to add multiple elements")
}
Expand All @@ -83,8 +84,8 @@ func TestDelete(t *testing.T) {
func TestMoveFront(t *testing.T) {
ch := new(CacheHash)
ch.Init(5)
ch.Add("key1", "value1")
ch.Add("key2", "value2")
ch.Upsert("key1", "value1")
ch.Upsert("key2", "value2")
ch.Get("key1")
if k, v := ch.First(); k != "key1" || v != "value1" {
t.Error("first and last not set on add")
Expand All @@ -97,9 +98,9 @@ func TestMoveFront(t *testing.T) {
func TestEject(t *testing.T) {
ch := new(CacheHash)
ch.Init(2)
ch.Add("key1", "value1")
ch.Add("key2", "value2")
ch.Add("key3", "value3")
ch.Upsert("key1", "value1")
ch.Upsert("key2", "value2")
ch.Upsert("key3", "value3")
if ch.Len() != 2 {
t.Error("length not respected")
}
Expand All @@ -113,3 +114,66 @@ func TestEject(t *testing.T) {
t.Error("Ejected element not removed from hash")
}
}

func TestUpsertExistingBumpsToFront(t *testing.T) {
ch := new(CacheHash)
ch.Init(5)
ch.Upsert("key1", "value1")
ch.Upsert("key2", "value2")
ch.Upsert("key3", "value3")
ch.Upsert("key1", "newValue1")
k, v := ch.First()
assert.Equal(t, "key1", k, "key1 should be bumped to front since it was just added")
assert.Equal(t, "newValue1", v, "add existing should update value")

k, v = ch.Last()
assert.Equal(t, "key2", k, "key2 should be last")
assert.Equal(t, "value2", v, "key2 should have value: value2")
}

func TestUpsertWithFullCache(t *testing.T) {
ch := new(CacheHash)
ch.Init(2)
ch.Upsert("key1", "value1")
ch.Upsert("key2", "value2")

k, v := ch.First()
assert.Equal(t, "key2", k, "First key should be key2")
assert.Equal(t, "value2", v, "First value should be value2")

k, v = ch.Last()
assert.Equal(t, "key1", k, "Last key should be key1")
assert.Equal(t, "value1", v, "Last value should be value1")

ch.Upsert("key3", "value3")

assert.Len(t, ch.h, 2, "Cache should have 2 elements, since it is full and one was evicted")

k, v = ch.First()
assert.Equal(t, "key3", k, "First key should be key3")
assert.Equal(t, "value3", v, "First value should be value3")

// key1 should have been evicted since it was the oldest, and key2 should still be in the cache
k, v = ch.Last()
assert.Equal(t, "key2", k, "Last key should be key2")
assert.Equal(t, "value2", v, "Last value should be value2")
}

func TestGetNoMove(t *testing.T) {
ch := new(CacheHash)
ch.Init(5)
ch.Upsert("key1", "value1")
ch.Upsert("key2", "value2")
ch.GetNoMove("key1")

k, v := ch.First()
assert.Equal(t, "key2", k, "First key should be key2")
assert.Equal(t, "value2", v, "First value should be value2")

v, found := ch.GetNoMove("key1")
assert.True(t, found, "key1 should be found")

k, v = ch.First()
assert.Equal(t, "key2", k, "First key should still be key2 post GetNoMove")
assert.Equal(t, "value2", v, "First value should be value2")
}
2 changes: 1 addition & 1 deletion src/internal/cachehash/shardedcachehash.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ func (c *ShardedCacheHash) getShard(k interface{}) *CacheHash {
}

func (c *ShardedCacheHash) Add(k interface{}, v interface{}) bool {
return c.getShard(k).Add(k, v)
return c.getShard(k).Upsert(k, v)
}

func (c *ShardedCacheHash) Get(k interface{}) (interface{}, bool) {
Expand Down
6 changes: 4 additions & 2 deletions src/modules/mxlookup/mx_lookup.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,13 @@ import (
"github.com/pkg/errors"
"github.com/spf13/pflag"
"github.com/zmap/dns"
"github.com/zmap/zdns/src/cli"
"strings"
"sync"

"github.com/zmap/zdns/src/cli"

log "github.com/sirupsen/logrus"

"github.com/zmap/zdns/src/internal/cachehash"
"github.com/zmap/zdns/src/zdns"
)
Expand Down Expand Up @@ -115,7 +117,7 @@ func (mxMod *MXLookupModule) lookupIPs(r *zdns.Resolver, name, nameServer string
retv.IPv6Addresses = result.IPv6Addresses
}
mxMod.CHmu.Lock()
mxMod.CacheHash.Add(name, retv)
mxMod.CacheHash.Upsert(name, retv)
mxMod.CHmu.Unlock()
return retv, trace
}
Expand Down
3 changes: 2 additions & 1 deletion src/zdns/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
log "github.com/sirupsen/logrus"

"github.com/zmap/dns"

"github.com/zmap/zdns/src/internal/cachehash"
)

Expand Down Expand Up @@ -82,7 +83,7 @@ func (s *Cache) AddCachedAnswer(answer interface{}, depth int) {
ExpiresAt: expiresAt}
ca.Answers[a] = ta
s.IterativeCache.Add(q, ca)
s.VerboseLog(depth+1, "Add cached answer ", q, " ", ca)
s.VerboseLog(depth+1, "Upsert cached answer ", q, " ", ca)
}

func (s *Cache) GetCachedResult(q Question, isAuthCheck bool, depth int) (SingleQueryResult, bool) {
Expand Down