Skip to content
This repository has been archived by the owner on Jun 20, 2024. It is now read-only.

Fix for negative TTLs in WeaveDNS #501

Merged
merged 1 commit into from
Apr 13, 2015
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 16 additions & 68 deletions nameserver/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ const (
defPendingTimeout int = 5 // timeout for a resolution
)

// nullTTL is the "no explicit TTL" marker: when setReply is given this value
// it does not use the ttl argument and, if a reply is present, derives the
// expiry from the reply instead.
const nullTTL = 0 // a null TTL

// entryStatus is the resolution state of a cache entry; the code in this file
// compares it against stPending and stResolved.
type entryStatus uint8

const (
Expand Down Expand Up @@ -61,24 +63,15 @@ type cacheEntry struct {
validUntil time.Time // obtained from the reply and stored here for convenience/speed
putTime time.Time

waitChan chan struct{}

index int // for fast lookups in the heap
}

func newCacheEntry(question *dns.Question, reply *dns.Msg, status entryStatus, flags uint8, now time.Time) *cacheEntry {
func newCacheEntry(question *dns.Question, now time.Time) *cacheEntry {
e := &cacheEntry{
Status: status,
Flags: flags,
question: *question,
index: -1,
}

if e.Status == stPending {
e.validUntil = now.Add(time.Duration(defPendingTimeout) * time.Second)
e.waitChan = make(chan struct{})
} else {
e.setReply(reply, flags, now)
Status: stPending,
validUntil: now.Add(time.Second * time.Duration(defPendingTimeout)),
question: *question,
index: -1,
}

return e
Expand Down Expand Up @@ -133,9 +126,7 @@ func (e cacheEntry) hasExpired(now time.Time) bool {

// setReply sets the reply for the entry.
// It returns true if the entry's validUntil time has changed.
func (e *cacheEntry) setReply(reply *dns.Msg, flags uint8, now time.Time) bool {
shouldNotify := (e.Status == stPending)

func (e *cacheEntry) setReply(reply *dns.Msg, ttl int, flags uint8, now time.Time) bool {
var prevValidUntil time.Time
if e.Status == stResolved {
prevValidUntil = e.validUntil
Expand All @@ -145,10 +136,9 @@ func (e *cacheEntry) setReply(reply *dns.Msg, flags uint8, now time.Time) bool {
e.Flags = flags
e.putTime = now

if e.Flags&CacheNoLocalReplies != 0 {
// use a fixed timeout for negative local resolutions
e.validUntil = now.Add(time.Second * time.Duration(negLocalTTL))
} else {
if ttl != nullTTL {
e.validUntil = now.Add(time.Second * time.Duration(ttl))
} else if reply != nil {
// calculate the validUntil from the reply TTL
var minTTL uint32 = math.MaxUint32
for _, rr := range reply.Answer {
Expand All @@ -165,37 +155,9 @@ func (e *cacheEntry) setReply(reply *dns.Msg, flags uint8, now time.Time) bool {
e.ReplyLen = reply.Len()
}

if shouldNotify {
close(e.waitChan) // notify all the waiters by closing the channel
}

return (prevValidUntil != e.validUntil)
}

// waitReply waits until a valid reply is set in the cache entry.
//
// If the entry is already resolved, the reply is built immediately with
// getReply. Otherwise, when timeout is positive, it blocks until waitChan is
// closed (and then builds the reply) or until timeout elapses, in which case
// errTimeout is returned. A non-positive timeout on an unresolved entry
// returns errCouldNotResolve without waiting.
//
// timeout is a real time.Duration (for example 1*time.Second), not a bare
// count of seconds.
func (e *cacheEntry) waitReply(request *dns.Msg, timeout time.Duration, maxLen int, now time.Time) (*dns.Msg, error) {
	if e.Status == stResolved {
		return e.getReply(request, maxLen, now)
	}

	if timeout > 0 {
		select {
		case <-e.waitChan:
			return e.getReply(request, maxLen, now)
		// timeout is already a Duration: scaling it by time.Second again
		// would turn 1s into roughly 31 years and overflow for larger values.
		case <-time.After(timeout):
			return nil, errTimeout
		}
	}

	return nil, errCouldNotResolve
}

// close closes the entry's waitChan, but only while the entry is still
// pending; entries in any other state are left untouched.
func (e *cacheEntry) close() {
	if e.Status != stPending {
		return
	}
	close(e.waitChan)
}

//////////////////////////////////////////////////////////////////////////////////////

// An entriesPtrHeap is a min-heap of cache entries.
Expand Down Expand Up @@ -280,27 +242,26 @@ func (c *Cache) Purge(now time.Time) {
}

// Put adds a reply to the cache, replacing any existing reply for the same question.
func (c *Cache) Put(request *dns.Msg, reply *dns.Msg, flags uint8, now time.Time) int {
func (c *Cache) Put(request *dns.Msg, reply *dns.Msg, ttl int, flags uint8, now time.Time) int {
c.lock.Lock()
defer c.lock.Unlock()

question := request.Question[0]
key := cacheKey(question)
ent, found := c.entries[key]
if found {
Debug.Printf("[cache msgid %d] replacing response in cache", request.MsgHdr.Id)
updated := ent.setReply(reply, flags, now)
updated := ent.setReply(reply, ttl, flags, now)
if updated {
heap.Fix(&c.entriesH, ent.index)
}
} else {
// If we will add a new item and the capacity has been exceeded, make some room...
if len(c.entriesH) >= c.Capacity {
lowestEntry := heap.Pop(&c.entriesH).(*cacheEntry)
lowestEntry.close()
delete(c.entries, cacheKey(lowestEntry.question))
}
ent = newCacheEntry(&question, reply, stResolved, flags, now)
ent = newCacheEntry(&question, now)
ent.setReply(reply, ttl, flags, now)
heap.Push(&c.entriesH, ent)
c.entries[key] = ent
}
Expand Down Expand Up @@ -329,20 +290,7 @@ func (c *Cache) Get(request *dns.Msg, maxLen int, now time.Time) (reply *dns.Msg
} else {
// we are the first asking for this name: create an entry with no reply... the caller must wait
Debug.Printf("[cache msgid %d] addind in pending state", request.MsgHdr.Id)
c.entries[key] = newCacheEntry(&question, nil, stPending, 0, now)
}
return
}

// Wait for a reply for a question in the cache
// Notice that the caller could Get() and then Wait() for a question, but the corresponding cache
// entry could have been removed in between. In that case, the caller should retry the query (and
// the user should increase the cache size!)
func (c *Cache) Wait(request *dns.Msg, timeout time.Duration, maxLen int, now time.Time) (reply *dns.Msg, err error) {
// do not try to lock the cache: otherwise, no one else could `Put()` the reply
question := request.Question[0]
if entry, found := c.entries[cacheKey(question)]; found {
reply, err = entry.waitReply(request, timeout, maxLen, now)
c.entries[key] = newCacheEntry(&question, now)
}
return
}
Expand Down
76 changes: 12 additions & 64 deletions nameserver/cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func TestCacheLength(t *testing.T) {
reply := makeAddressReply(questionMsg, question, ips)
reply.Answer[0].Header().Ttl = uint32(i)

l.Put(questionMsg, reply, 0, insTime)
l.Put(questionMsg, reply, 0, 0, insTime)
}

wt.AssertEqualInt(t, l.Len(), cacheLen, "cache length")
Expand Down Expand Up @@ -69,20 +69,20 @@ func TestCacheEntries(t *testing.T) {
resp, err := l.Get(questionMsg, minUDPSize, time.Now())
wt.AssertNoErr(t, err)
if resp != nil {
t.Logf("Got '%s'", resp)
t.Logf("Got\n%s", resp)
t.Fatalf("ERROR: Did not expect a reponse from Get() yet")
}
t.Logf("Trying to get it again")
resp, err = l.Get(questionMsg, minUDPSize, time.Now())
wt.AssertNoErr(t, err)
if resp != nil {
t.Logf("Got '%s'", resp)
t.Logf("Got\n%s", resp)
t.Fatalf("ERROR: Did not expect a reponse from Get() yet")
}

t.Logf("Inserting the reply")
reply1 := makeAddressReply(questionMsg, question, []net.IP{net.ParseIP("10.0.1.1")})
l.Put(questionMsg, reply1, 0, time.Now())
l.Put(questionMsg, reply1, nullTTL, 0, time.Now())

timeGet1 := time.Now()
t.Logf("Checking we can Get() the reply now")
Expand All @@ -93,13 +93,6 @@ func TestCacheEntries(t *testing.T) {
wt.AssertType(t, resp.Answer[0], (*dns.A)(nil), "DNS record")
ttlGet1 := resp.Answer[0].Header().Ttl

t.Logf("Checking a Wait() with timeout=0 gets the same result")
resp, err = l.Wait(questionMsg, time.Duration(0)*time.Second, minUDPSize, time.Now())
wt.AssertNoErr(t, err)
wt.AssertTrue(t, resp != nil, "reponse from a Wait(timeout=0)")
t.Logf("Received '%s'", resp.Answer[0])
wt.AssertType(t, resp.Answer[0], (*dns.A)(nil), "DNS record")

timeGet2 := timeGet1.Add(time.Duration(1) * time.Second)
t.Logf("Checking that a second Get(), after 1 second, gets the same result, but with reduced TTL")
resp, err = l.Get(questionMsg, minUDPSize, timeGet2)
Expand All @@ -115,13 +108,13 @@ func TestCacheEntries(t *testing.T) {
resp, err = l.Get(questionMsg, minUDPSize, timeGet3)
wt.AssertNoErr(t, err)
if resp != nil {
t.Logf("Got '%s'", resp)
t.Logf("Got\n%s", resp)
t.Fatalf("ERROR: Did NOT expect a reponse from the second Get()")
}

t.Logf("Checking that an Remove() results in Get() returning nothing")
replyTemp := makeAddressReply(questionMsg, question, []net.IP{net.ParseIP("10.0.9.9")})
l.Put(questionMsg, replyTemp, 0, time.Now())
l.Put(questionMsg, replyTemp, nullTTL, 0, time.Now())
lenBefore := l.Len()
l.Remove(question)
wt.AssertEqualInt(t, l.Len(), lenBefore-1, "cache length")
Expand All @@ -135,10 +128,10 @@ func TestCacheEntries(t *testing.T) {
t.Logf("Inserting a two replies for the same query")
timePut2 := time.Now()
reply2 := makeAddressReply(questionMsg, question, []net.IP{net.ParseIP("10.0.1.2")})
l.Put(questionMsg, reply2, 0, timePut2)
l.Put(questionMsg, reply2, nullTTL, 0, timePut2)
timePut3 := timePut2.Add(time.Duration(1) * time.Second)
reply3 := makeAddressReply(questionMsg, question, []net.IP{net.ParseIP("10.0.1.3")})
l.Put(questionMsg, reply3, 0, timePut3)
l.Put(questionMsg, reply3, nullTTL, 0, timePut3)

t.Logf("Checking we get the last one...")
resp, err = l.Get(questionMsg, minUDPSize, timePut3)
Expand All @@ -162,7 +155,7 @@ func TestCacheEntries(t *testing.T) {
resp, err = l.Get(questionMsg, minUDPSize, timePut3.Add(time.Duration(localTTL)*time.Second))
wt.AssertNoErr(t, err)
if resp != nil {
t.Logf("Received '%s'", resp.Answer[0])
t.Logf("Got\n%s", resp.Answer[0])
t.Fatalf("ERROR: Did NOT expect a reponse from the Get()")
}
wt.AssertEqualInt(t, l.Len(), lenBefore-1, "cache length (after getting an expired entry)")
Expand All @@ -180,13 +173,10 @@ func TestCacheEntries(t *testing.T) {
t.Logf("Checking that an Remove() between Get() and Put() does not break things")
replyTemp2 := makeAddressReply(questionMsg2, question2, []net.IP{net.ParseIP("10.0.9.9")})
l.Remove(question2)
l.Put(questionMsg2, replyTemp2, 0, time.Now())
l.Put(questionMsg2, replyTemp2, nullTTL, 0, time.Now())
resp, err = l.Get(questionMsg2, minUDPSize, time.Now())
wt.AssertNoErr(t, err)
wt.AssertNotNil(t, resp, "reponse from Get()")
resp, err = l.Wait(questionMsg2, time.Duration(0)*time.Second, minUDPSize, time.Now())
wt.AssertNoErr(t, err)
wt.AssertNotNil(t, resp, "reponse from Get()")

questionMsg3 := new(dns.Msg)
questionMsg3.SetQuestion("some.other.name", dns.TypeA)
Expand All @@ -195,7 +185,7 @@ func TestCacheEntries(t *testing.T) {

t.Logf("Checking that a entry with CacheNoLocalReplies return an error")
timePut3 = time.Now()
l.Put(questionMsg3, nil, CacheNoLocalReplies, timePut3)
l.Put(questionMsg3, nil, nullTTL, CacheNoLocalReplies, timePut3)
resp, err = l.Get(questionMsg3, minUDPSize, timePut3)
wt.AssertNil(t, resp, "Get() response with CacheNoLocalReplies")
wt.AssertNotNil(t, err, "Get() error with CacheNoLocalReplies")
Expand All @@ -208,51 +198,9 @@ func TestCacheEntries(t *testing.T) {

l.Remove(question3)
t.Logf("Checking that Put&Get with CacheNoLocalReplies with a Remove in the middle returns nothing")
l.Put(questionMsg3, nil, CacheNoLocalReplies, time.Now())
l.Put(questionMsg3, nil, nullTTL, CacheNoLocalReplies, time.Now())
l.Remove(question3)
resp, err = l.Get(questionMsg3, minUDPSize, time.Now())
wt.AssertNil(t, resp, "Get() reponse with CacheNoLocalReplies")
wt.AssertNil(t, err, "Get() error with CacheNoLocalReplies")
}

// Check that waiters are unblocked when the name they are waiting for is inserted.
//
// Each querying goroutine reports on a buffered channel when it finishes, and
// the test drains that channel (bounded by an overall deadline) before
// returning, so the goroutines do not keep using t after the test function
// has returned.
func TestCacheBlockingOps(t *testing.T) {
	InitDefaultLogging(true)

	const cacheLen = 256

	l, err := NewCache(cacheLen)
	wt.AssertNoErr(t, err)

	requests := make([]*dns.Msg, 0, cacheLen)
	// buffered, so a finished waiter never blocks while reporting completion
	done := make(chan struct{}, cacheLen)

	t.Logf("Starting 256 queries that will block...")
	for i := 0; i < cacheLen; i++ {
		questionName := fmt.Sprintf("name%d", i)
		questionMsg := new(dns.Msg)
		questionMsg.SetQuestion(questionName, dns.TypeA)
		questionMsg.RecursionDesired = true

		requests = append(requests, questionMsg)

		go func(request *dns.Msg) {
			// deferred so completion is reported on every exit path of this goroutine
			// NOTE(review): wt.AssertNoErr is not visible here; if it calls
			// t.Fatal it should be replaced by t.Errorf inside this goroutine.
			defer func() { done <- struct{}{} }()

			t.Logf("Querying about %s...", request.Question[0].Name)
			_, err := l.Get(request, minUDPSize, time.Now())
			wt.AssertNoErr(t, err)
			t.Logf("Waiting for %s...", request.Question[0].Name)
			r, err := l.Wait(request, 1*time.Second, minUDPSize, time.Now())
			t.Logf("Obtained response for %s:\n%s", request.Question[0].Name, r)
			wt.AssertNoErr(t, err)
		}(questionMsg)
	}

	// insert the IPs for those names
	for i, requestMsg := range requests {
		ip := net.ParseIP(fmt.Sprintf("10.0.1.%d", i))
		ips := []net.IP{ip}
		reply := makeAddressReply(requestMsg, &requestMsg.Question[0], ips)

		t.Logf("Inserting response for %s...", requestMsg.Question[0].Name)
		l.Put(requestMsg, reply, 0, time.Now())
	}

	// join every waiter before returning; fail instead of hanging if some never finish
	deadline := time.After(10 * time.Second)
	for finished := 0; finished < cacheLen; finished++ {
		select {
		case <-done:
		case <-deadline:
			t.Fatalf("timed out waiting for %d of %d waiters", cacheLen-finished, cacheLen)
		}
	}
}
Loading