Skip to content

Commit

Permalink
Add IndependentCache option
Browse files Browse the repository at this point in the history
Co-authored-by: 世界 <i@sekai.icu>
Co-authored-by: armv9 <48624112+arm64v8a@users.noreply.github.com>
  • Loading branch information
nekohasekai and arm64v8a committed Apr 26, 2023
1 parent 442595d commit 18b00d0
Showing 1 changed file with 89 additions and 31 deletions.
120 changes: 89 additions & 31 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,21 +24,39 @@ var (
)

type Client struct {
disableCache bool
disableExpire bool
logger logger.ContextLogger
disableCache bool
disableExpire bool
independentCache bool
logger logger.ContextLogger
cache *cache.LruCache[dns.Question, *dns.Msg]
transportCache *cache.LruCache[transportCacheKey, *dns.Msg]
}

type transportCacheKey struct {
dns.Question
transportName string
}

cache *cache.LruCache[dns.Question, *dns.Msg]
type ClientOptions struct {
DisableCache bool
DisableExpire bool
IndependentCache bool
Logger logger.ContextLogger
}

func NewClient(disableCache bool, disableExpire bool, logger logger.ContextLogger) *Client {
func NewClient(options ClientOptions) *Client {
client := &Client{
disableCache: disableCache,
disableExpire: disableExpire,
logger: logger,
}
if !disableCache {
client.cache = cache.New[dns.Question, *dns.Msg]()
disableCache: options.DisableCache,
disableExpire: options.DisableExpire,
independentCache: options.IndependentCache,
logger: options.Logger,
}
if !client.disableCache {
if !client.independentCache {
client.cache = cache.New[dns.Question, *dns.Msg]()
} else {
client.transportCache = cache.New[transportCacheKey, *dns.Msg]()
}
}
return client
}
Expand Down Expand Up @@ -76,7 +94,7 @@ func (c *Client) exchange(ctx context.Context, transport Transport, message *dns
question := message.Question[0]
disableCache := c.disableCache || DisableCacheFromContext(ctx)
if !disableCache {
response, ttl := c.loadResponse(question)
response, ttl := c.loadResponse(question, transport)
if response != nil {
logCachedResponse(c.logger, ctx, response, ttl)
response.Id = message.Id
Expand Down Expand Up @@ -128,7 +146,7 @@ func (c *Client) exchange(ctx context.Context, transport Transport, message *dns

response.Id = messageId
if !disableCache {
c.storeCache(question, response, timeToLive)
c.storeCache(transport, question, response, timeToLive)
}

return response, err
Expand Down Expand Up @@ -177,7 +195,7 @@ func (c *Client) Lookup(ctx context.Context, transport Transport, domain string,
Name: dnsName,
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
})
}, transport)
if err != ErrNotCached {
return response, err
}
Expand All @@ -186,7 +204,7 @@ func (c *Client) Lookup(ctx context.Context, transport Transport, domain string,
Name: dnsName,
Qtype: dns.TypeAAAA,
Qclass: dns.ClassINET,
})
}, transport)
if err != ErrNotCached {
return response, err
}
Expand All @@ -195,12 +213,12 @@ func (c *Client) Lookup(ctx context.Context, transport Transport, domain string,
Name: dnsName,
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
})
}, transport)
response6, _ := c.questionCache(dns.Question{
Name: dnsName,
Qtype: dns.TypeAAAA,
Qclass: dns.ClassINET,
})
}, transport)
if len(response4) > 0 || len(response6) > 0 {
return sortAddresses(response4, response6, strategy), nil
}
Expand Down Expand Up @@ -250,7 +268,7 @@ func (c *Client) Lookup(ctx context.Context, transport Transport, domain string,
})
}
}
c.storeCache(question4, message4, DefaultTTL)
c.storeCache(transport, question4, message4, DefaultTTL)
}
if strategy != DomainStrategyUseIPv4 {
question6 := dns.Question{
Expand Down Expand Up @@ -278,7 +296,7 @@ func (c *Client) Lookup(ctx context.Context, transport Transport, domain string,
})
}
}
c.storeCache(question6, message6, DefaultTTL)
c.storeCache(transport, question6, message6, DefaultTTL)
}
}
return response, err
Expand All @@ -292,13 +310,27 @@ func sortAddresses(response4 []netip.Addr, response6 []netip.Addr, strategy Doma
}
}

func (c *Client) storeCache(question dns.Question, message *dns.Msg, timeToLive int) {
func (c *Client) storeCache(transport Transport, question dns.Question, message *dns.Msg, timeToLive int) {
if c.disableExpire {
c.cache.Store(question, message)
if !c.independentCache {
c.cache.Store(question, message)
} else {
c.transportCache.Store(transportCacheKey{
Question: question,
transportName: transport.Name(),
}, message)
}
return
}
expireAt := time.Now().Add(time.Second * time.Duration(timeToLive))
c.cache.StoreWithExpire(question, message, expireAt)
if !c.independentCache {
c.cache.StoreWithExpire(question, message, expireAt)
} else {
c.transportCache.StoreWithExpire(transportCacheKey{
Question: question,
transportName: transport.Name(),
}, message, expireAt)
}
}

func (c *Client) exchangeToLookup(ctx context.Context, transport Transport, message *dns.Msg, question dns.Question) (*dns.Msg, error) {
Expand Down Expand Up @@ -370,7 +402,7 @@ func (c *Client) lookupToExchange(ctx context.Context, transport Transport, name
}
disableCache := c.disableCache || DisableCacheFromContext(ctx)
if !disableCache {
cachedAddresses, err := c.questionCache(question)
cachedAddresses, err := c.questionCache(question, transport)
if err != ErrNotCached {
return cachedAddresses, err
}
Expand All @@ -388,33 +420,59 @@ func (c *Client) lookupToExchange(ctx context.Context, transport Transport, name
return messageToAddresses(response)
}

func (c *Client) questionCache(question dns.Question) ([]netip.Addr, error) {
response, _ := c.loadResponse(question)
func (c *Client) questionCache(question dns.Question, transport Transport) ([]netip.Addr, error) {
response, _ := c.loadResponse(question, transport)
if response == nil {
return nil, ErrNotCached
}
return messageToAddresses(response)
}

func (c *Client) loadResponse(question dns.Question) (*dns.Msg, int) {
func (c *Client) loadResponse(question dns.Question, transport Transport) (*dns.Msg, int) {
var (
response *dns.Msg
loaded bool
)
if c.disableExpire {
response, loaded := c.cache.Load(question)
if !c.independentCache {
response, loaded = c.cache.Load(question)
} else {
response, loaded = c.transportCache.Load(transportCacheKey{
Question: question,
transportName: transport.Name(),
})
}
if !loaded {
return nil, 0
}
return response.Copy(), 0
} else {
cachedAnswer, expireAt, loaded := c.cache.LoadWithExpire(question)
var expireAt time.Time
if !c.independentCache {
response, expireAt, loaded = c.cache.LoadWithExpire(question)
} else {
response, expireAt, loaded = c.transportCache.LoadWithExpire(transportCacheKey{
Question: question,
transportName: transport.Name(),
})
}
if !loaded {
return nil, 0
}
timeNow := time.Now()
if timeNow.After(expireAt) {
c.cache.Delete(question)
if !c.independentCache {
c.cache.Delete(question)
} else {
c.transportCache.Delete(transportCacheKey{
Question: question,
transportName: transport.Name(),
})
}
return nil, 0
}
var originTTL int
for _, recordList := range [][]dns.RR{cachedAnswer.Answer, cachedAnswer.Ns, cachedAnswer.Extra} {
for _, recordList := range [][]dns.RR{response.Answer, response.Ns, response.Extra} {
for _, record := range recordList {
if originTTL == 0 || record.Header().Ttl > 0 && int(record.Header().Ttl) < originTTL {
originTTL = int(record.Header().Ttl)
Expand All @@ -425,7 +483,7 @@ func (c *Client) loadResponse(question dns.Question) (*dns.Msg, int) {
if nowTTL < 0 {
nowTTL = 0
}
response := cachedAnswer.Copy()
response = response.Copy()
if originTTL > 0 {
duration := uint32(originTTL - nowTTL)
for _, recordList := range [][]dns.RR{response.Answer, response.Ns, response.Extra} {
Expand Down

0 comments on commit 18b00d0

Please sign in to comment.