From a03a56c89026753989a4b06787ef1a08dfbabb9c Mon Sep 17 00:00:00 2001 From: Andrey Meshkov Date: Tue, 20 Sep 2022 15:38:01 +0300 Subject: [PATCH] Pull request: proxy: added HTTP/3 support to the DNS-over-HTTPS server implementation Merge in DNS/dnsproxy from doh3server to master Squashed commit of the following: commit dd7f6ecb0264afd16ee6fcd47ff7bafe06797645 Author: Andrey Meshkov Date: Tue Sep 20 14:17:51 2022 +0300 upstream: fix review comments commit 3b887f614163f4900f75807c990ad2a5d354d3b5 Author: Andrey Meshkov Date: Tue Sep 20 00:14:19 2022 +0300 proxy: added address validation logic commit b29dc3c3b6746ad5be921941904f16ab228b1dab Author: Andrey Meshkov Date: Mon Sep 19 23:31:21 2022 +0300 proxy: fix review comments, general improvements commit 79f47f54adcd30a68a9f7bc0111025ae0a32d99d Author: Andrey Meshkov Date: Mon Sep 19 20:43:26 2022 +0300 upstream: several improvements in DoH3 and DoQ upstreams The previous implementation wasn't able to properly handle the situation when the server was restarted. This commit greatly improves the overall stability. commit 59cf92b6097d78acf6f088057134888993f7ca43 Author: Andrey Meshkov Date: Sat Sep 17 02:51:40 2022 +0300 proxy: remoteAddr for DoH depends on HTTP version now commit 804ddedd2807870b7d36dae5ce9857de3a7f7286 Author: Andrey Meshkov Date: Sat Sep 17 01:53:32 2022 +0300 proxy: added HTTP/3 support to the DNS-over-HTTPS server implementation The implementation follows the old approach used in dnsproxy, i.e. it adds another set of "listeners"; the new ones are for HTTP/3. HTTP/3 support is not enabled by default; it must be enabled explicitly by setting the HTTP3 field of proxy.Config to true. The "--http3" command-line argument now controls DoH3 support on both the client-side and the server-side. There's one more important change that was made while refactoring the code. Previously, we were creating a separate http.Server instance for every listen address in use. It is unclear to me what the reason for that was, since a single instance can serve on every address. This mistake is fixed now.
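For illustration only (this snippet is not part of the patch), a minimal server-side setup using the new option might look like the sketch below. The upstream address and the certificate paths are placeholders, and the example assumes the embedded proxy.Config fields touched by this diff (UpstreamConfig, TLSConfig, HTTPSListenAddr, and the new HTTP3 flag):

```go
package main

import (
	"crypto/tls"
	"net"

	"github.com/AdguardTeam/dnsproxy/proxy"
	"github.com/AdguardTeam/dnsproxy/upstream"
)

func main() {
	// Forward queries to a single plain-DNS upstream; the address is just
	// an example.
	u, err := upstream.AddressToUpstream("8.8.8.8:53", &upstream.Options{})
	if err != nil {
		panic(err)
	}

	// example.crt and example.key are placeholder paths.
	cert, err := tls.LoadX509KeyPair("example.crt", "example.key")
	if err != nil {
		panic(err)
	}

	p := &proxy.Proxy{}
	p.UpstreamConfig = &proxy.UpstreamConfig{Upstreams: []upstream.Upstream{u}}
	p.TLSConfig = &tls.Config{Certificates: []tls.Certificate{cert}}
	p.HTTPSListenAddr = []*net.TCPAddr{{IP: net.IP{127, 0, 0, 1}, Port: 443}}

	// The new option: also serve DoH3 over UDP on the same addresses.
	p.HTTP3 = true

	if err = p.Start(); err != nil {
		panic(err)
	}

	// Block forever while the listeners serve queries.
	select {}
}
```

On the client side, the same effect is achieved by passing "--http3" (or, programmatically, by including HTTP/3 in upstream.Options.HTTPVersions), as exercised by the tests further down in this patch.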
--- README.md | 5 + fastip/fastest.go | 3 +- fastip/ping_test.go | 6 +- go.mod | 1 + go.sum | 2 + main.go | 4 +- proxy/config.go | 1 + proxy/proxy.go | 107 ++-- proxy/proxy_test.go | 8 +- proxy/server.go | 9 +- proxy/server_dnscrypt.go | 23 +- proxy/server_https.go | 130 +++-- proxy/server_https_test.go | 194 +++++--- proxy/server_quic.go | 123 ++++- proxy/server_quic_test.go | 2 +- upstream/bootstrap.go | 13 + upstream/upstream_doh.go | 120 ++++- upstream/upstream_doh_test.go | 103 +++- upstream/upstream_quic.go | 207 +++++--- upstream/upstream_quic_test.go | 164 ++++++- vendor/github.com/bluele/gcache/LICENSE | 21 + vendor/github.com/bluele/gcache/README.md | 320 ++++++++++++ vendor/github.com/bluele/gcache/arc.go | 456 ++++++++++++++++++ vendor/github.com/bluele/gcache/cache.go | 205 ++++++++ vendor/github.com/bluele/gcache/clock.go | 53 ++ vendor/github.com/bluele/gcache/lfu.go | 377 +++++++++++++++ vendor/github.com/bluele/gcache/lru.go | 317 ++++++++++++ vendor/github.com/bluele/gcache/simple.go | 307 ++++++++++++ .../github.com/bluele/gcache/singleflight.go | 82 ++++ vendor/github.com/bluele/gcache/stats.go | 53 ++ vendor/github.com/bluele/gcache/utils.go | 15 + vendor/modules.txt | 3 + 32 files changed, 3140 insertions(+), 294 deletions(-) create mode 100644 vendor/github.com/bluele/gcache/LICENSE create mode 100644 vendor/github.com/bluele/gcache/README.md create mode 100644 vendor/github.com/bluele/gcache/arc.go create mode 100644 vendor/github.com/bluele/gcache/cache.go create mode 100644 vendor/github.com/bluele/gcache/clock.go create mode 100644 vendor/github.com/bluele/gcache/lfu.go create mode 100644 vendor/github.com/bluele/gcache/lru.go create mode 100644 vendor/github.com/bluele/gcache/simple.go create mode 100644 vendor/github.com/bluele/gcache/singleflight.go create mode 100644 vendor/github.com/bluele/gcache/stats.go create mode 100644 vendor/github.com/bluele/gcache/utils.go diff --git a/README.md b/README.md index ff36c26c6..425258f01 100644 --- a/README.md +++ b/README.md @@ -181,6 +181,11 @@ Runs a DNS-over-HTTPS proxy on `127.0.0.1:443`. ./dnsproxy -l 127.0.0.1 --https-port=443 --tls-crt=example.crt --tls-key=example.key -u 8.8.8.8:53 -p 0 ``` +Runs a DNS-over-HTTPS proxy on `127.0.0.1:443` with HTTP/3 support. +```shell +./dnsproxy -l 127.0.0.1 --https-port=443 --http3 --tls-crt=example.crt --tls-key=example.key -u 8.8.8.8:53 -p 0 +``` + Runs a DNS-over-QUIC proxy on `127.0.0.1:853`. 
```shell ./dnsproxy -l 127.0.0.1 --quic-port=853 --tls-crt=example.crt --tls-key=example.key -u 8.8.8.8:53 -p 0 diff --git a/fastip/fastest.go b/fastip/fastest.go index 685cbd321..c30d43339 100644 --- a/fastip/fastest.go +++ b/fastip/fastest.go @@ -6,11 +6,10 @@ import ( "sync" "time" - "github.com/AdguardTeam/golibs/log" - "github.com/AdguardTeam/dnsproxy/proxyutil" "github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/golibs/cache" + "github.com/AdguardTeam/golibs/log" "github.com/miekg/dns" ) diff --git a/fastip/ping_test.go b/fastip/ping_test.go index 23fc851ef..be1cf2606 100644 --- a/fastip/ping_test.go +++ b/fastip/ping_test.go @@ -9,6 +9,7 @@ import ( "time" "github.com/AdguardTeam/golibs/netutil" + "github.com/AdguardTeam/golibs/testutil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -99,7 +100,7 @@ func TestFastestAddr_PingAll_cache(t *testing.T) { t.Run("not_cached", func(t *testing.T) { listener, err := net.Listen("tcp", ":0") require.NoError(t, err) - t.Cleanup(func() { require.NoError(t, listener.Close()) }) + testutil.CleanupAndRequireSuccess(t, listener.Close) ip := net.IP{127, 0, 0, 1} f := NewFastestAddr() @@ -138,8 +139,7 @@ func listen(t *testing.T, ip net.IP) (port uint) { l, err := net.Listen("tcp", netutil.IPPort{IP: ip, Port: 0}.String()) require.NoError(t, err) - - t.Cleanup(func() { require.NoError(t, l.Close()) }) + testutil.CleanupAndRequireSuccess(t, l.Close) return uint(l.Addr().(*net.TCPAddr).Port) } diff --git a/go.mod b/go.mod index 52232f058..6485dd618 100644 --- a/go.mod +++ b/go.mod @@ -7,6 +7,7 @@ require ( github.com/ameshkov/dnscrypt/v2 v2.2.5 github.com/ameshkov/dnsstamps v1.0.3 github.com/beefsack/go-rate v0.0.0-20220214233405-116f4ca011a0 + github.com/bluele/gcache v0.0.2 github.com/jessevdk/go-flags v1.5.0 github.com/lucas-clemente/quic-go v0.29.0 github.com/miekg/dns v1.1.50 diff --git a/go.sum b/go.sum index 484a8304e..1bff13de1 100644 --- a/go.sum +++ b/go.sum @@ -10,6 +10,8 @@ github.com/ameshkov/dnsstamps v1.0.3 h1:Srzik+J9mivH1alRACTbys2xOxs0lRH9qnTA7Y1O github.com/ameshkov/dnsstamps v1.0.3/go.mod h1:Ii3eUu73dx4Vw5O4wjzmT5+lkCwovjzaEZZ4gKyIH5A= github.com/beefsack/go-rate v0.0.0-20220214233405-116f4ca011a0 h1:0b2vaepXIfMsG++IsjHiI2p4bxALD1Y2nQKGMR5zDQM= github.com/beefsack/go-rate v0.0.0-20220214233405-116f4ca011a0/go.mod h1:6YNgTHLutezwnBvyneBbwvB8C82y3dcoOj5EQJIdGXA= +github.com/bluele/gcache v0.0.2 h1:WcbfdXICg7G/DGBh1PFfcirkWOQV+v077yF1pSy3DGw= +github.com/bluele/gcache v0.0.2/go.mod h1:m15KV+ECjptwSPxKhOhQoAFQVtUFjTVkc3H8o0t/fp0= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= diff --git a/main.go b/main.go index df92091f1..6474e2c9c 100644 --- a/main.go +++ b/main.go @@ -81,8 +81,7 @@ type Options struct { DNSCryptConfigPath string `yaml:"dnscrypt-config" short:"g" long:"dnscrypt-config" description:"Path to a file with DNSCrypt configuration. You can generate one using https://github.com/ameshkov/dnscrypt"` // HTTP3 controls whether HTTP/3 is enabled for this instance of dnsproxy. - // At this point it only enables it for upstreams, but in the future it will - // also enable it for the server. + // It enables HTTP/3 support for both the DoH upstreams and the DoH server. 
HTTP3 bool `yaml:"http3" long:"http3" description:"Enable HTTP/3 support" optional:"yes" optional-value:"false"` // Upstream DNS servers settings @@ -274,6 +273,7 @@ func createProxyConfig(options *Options) proxy.Config { CacheMaxTTL: options.CacheMaxTTL, CacheOptimistic: options.CacheOptimistic, RefuseAny: options.RefuseAny, + HTTP3: options.HTTP3, // TODO(e.burkov): The following CIDRs are aimed to match any // address. This is not quite proper approach to be used by // default so think about configuring it. diff --git a/proxy/config.go b/proxy/config.go index 0e10bdd93..f9e1bfe43 100644 --- a/proxy/config.go +++ b/proxy/config.go @@ -54,6 +54,7 @@ type Config struct { // -- TLSConfig *tls.Config // necessary for TLS, HTTPS, QUIC + HTTP3 bool // if true, HTTPS server will also support HTTP/3 DNSCryptProviderName string // DNSCrypt provider name DNSCryptResolverCert *dnscrypt.Cert // DNSCrypt resolver certificate diff --git a/proxy/proxy.go b/proxy/proxy.go index 816a26a8b..8528ab9d1 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -4,6 +4,7 @@ package proxy import ( "fmt" + "io" "net" "net/http" "sync" @@ -18,6 +19,7 @@ import ( "github.com/AdguardTeam/golibs/netutil" "github.com/ameshkov/dnscrypt/v2" "github.com/lucas-clemente/quic-go" + "github.com/lucas-clemente/quic-go/http3" "github.com/miekg/dns" gocache "github.com/patrickmn/go-cache" ) @@ -65,15 +67,17 @@ type Proxy struct { // Listeners // -- - udpListen []*net.UDPConn // UDP listen connections - tcpListen []net.Listener // TCP listeners - tlsListen []net.Listener // TLS listeners - quicListen []quic.Listener // QUIC listeners - httpsListen []net.Listener // HTTPS listeners - httpsServer []*http.Server // HTTPS server instance - dnsCryptUDPListen []*net.UDPConn // UDP listen connections for DNSCrypt - dnsCryptTCPListen []net.Listener // TCP listeners for DNSCrypt - dnsCryptServer *dnscrypt.Server // DNSCrypt server instance + udpListen []*net.UDPConn // UDP listen connections + tcpListen []net.Listener // TCP listeners + tlsListen []net.Listener // TLS listeners + quicListen []quic.EarlyListener // QUIC listeners + httpsListen []net.Listener // HTTPS listeners + httpsServer *http.Server // HTTPS server instance + h3Listen []quic.EarlyListener // HTTP/3 listeners + h3Server *http3.Server // HTTP/3 server instance + dnsCryptUDPListen []*net.UDPConn // UDP listen connections for DNSCrypt + dnsCryptTCPListen []net.Listener // TCP listeners for DNSCrypt + dnsCryptServer *dnscrypt.Server // DNSCrypt server instance // Upstream // -- @@ -145,19 +149,6 @@ func (p *Proxy) Init() (err error) { p.requestGoroutinesSema = newNoopSemaphore() } - if p.DNSCryptResolverCert != nil && p.DNSCryptProviderName != "" { - log.Info("Initializing DNSCrypt: %s", p.DNSCryptProviderName) - p.dnsCryptServer = &dnscrypt.Server{ - ProviderName: p.DNSCryptProviderName, - ResolverCert: p.DNSCryptResolverCert, - Handler: &dnsCryptHandler{ - proxy: p, - - requestGoroutinesSema: p.requestGoroutinesSema, - }, - } - } - p.udpOOBSize = proxyutil.UDPGetOOBSize() p.bytesPool = &sync.Pool{ New: func() interface{} { @@ -212,6 +203,17 @@ func (p *Proxy) Start() (err error) { return nil } +// closeAll closes all elements in the toClose slice and if there's any error +// appends it to the errs slice. 
+func closeAll[T io.Closer](toClose []T, errs *[]error) { + for _, c := range toClose { + err := c.Close() + if err != nil { + *errs = append(*errs, err) + } + } +} + // Stop stops the proxy server including all its listeners func (p *Proxy) Stop() error { log.Info("Stopping the DNS proxy server") @@ -225,61 +227,38 @@ func (p *Proxy) Stop() error { errs := []error{} - for _, l := range p.tcpListen { - err := l.Close() - if err != nil { - errs = append(errs, fmt.Errorf("closing tcp listening socket: %w", err)) - } - } + closeAll(p.tcpListen, &errs) p.tcpListen = nil - for _, l := range p.udpListen { - err := l.Close() - if err != nil { - errs = append(errs, fmt.Errorf("closing udp listening socket: %w", err)) - } - } + closeAll(p.udpListen, &errs) p.udpListen = nil - for _, l := range p.tlsListen { - err := l.Close() - if err != nil { - errs = append(errs, fmt.Errorf("closing tls listening socket: %w", err)) - } - } + closeAll(p.tlsListen, &errs) p.tlsListen = nil - for _, srv := range p.httpsServer { - err := srv.Close() - if err != nil { - errs = append(errs, fmt.Errorf("closing https server: %w", err)) - } + if p.httpsServer != nil { + closeAll([]io.Closer{p.httpsServer}, &errs) + p.httpsServer = nil + + // No need to close these since they're closed by httpsServer.Close(). + p.httpsListen = nil } - p.httpsListen = nil - p.httpsServer = nil - for _, l := range p.quicListen { - err := l.Close() - if err != nil { - errs = append(errs, fmt.Errorf("closing quic listener: %w", err)) - } + if p.h3Server != nil { + closeAll([]io.Closer{p.h3Server}, &errs) + p.h3Server = nil } + + closeAll(p.h3Listen, &errs) + p.h3Listen = nil + + closeAll(p.quicListen, &errs) p.quicListen = nil - for _, l := range p.dnsCryptUDPListen { - err := l.Close() - if err != nil { - errs = append(errs, fmt.Errorf("closing dnscrypt udp listening socket: %w", err)) - } - } + closeAll(p.dnsCryptUDPListen, &errs) p.dnsCryptUDPListen = nil - for _, l := range p.dnsCryptTCPListen { - err := l.Close() - if err != nil { - errs = append(errs, fmt.Errorf("closing dnscrypt tcp listening socket: %w", err)) - } - } + closeAll(p.dnsCryptTCPListen, &errs) p.dnsCryptTCPListen = nil p.started = false diff --git a/proxy/proxy_test.go b/proxy/proxy_test.go index e68cd972f..6596d1319 100644 --- a/proxy/proxy_test.go +++ b/proxy/proxy_test.go @@ -736,9 +736,7 @@ func TestResponseInRequest(t *testing.T) { func TestNoQuestion(t *testing.T) { dnsProxy := createTestProxy(t, nil) require.NoError(t, dnsProxy.Start()) - t.Cleanup(func() { - require.NoError(t, dnsProxy.Stop()) - }) + testutil.CleanupAndRequireSuccess(t, dnsProxy.Stop) addr := dnsProxy.Addr(ProtoUDP) client := &dns.Client{Net: "udp", Timeout: 500 * time.Millisecond} @@ -780,9 +778,7 @@ func (wu *funcUpstream) Address() string { func TestProxy_ReplyFromUpstream_badResponse(t *testing.T) { dnsProxy := createTestProxy(t, nil) require.NoError(t, dnsProxy.Start()) - t.Cleanup(func() { - require.NoError(t, dnsProxy.Stop()) - }) + testutil.CleanupAndRequireSuccess(t, dnsProxy.Stop) exchangeFunc := func(m *dns.Msg) (resp *dns.Msg, err error) { resp = &dns.Msg{} diff --git a/proxy/server.go b/proxy/server.go index e9fc66c70..a406b4ae0 100644 --- a/proxy/server.go +++ b/proxy/server.go @@ -6,6 +6,7 @@ import ( "time" "github.com/AdguardTeam/golibs/log" + "github.com/lucas-clemente/quic-go" "github.com/miekg/dns" ) @@ -53,8 +54,12 @@ func (p *Proxy) startListeners() error { go p.tcpPacketLoop(l, ProtoTLS, p.requestGoroutinesSema) } - for i := range p.httpsServer { - go 
p.listenHTTPS(p.httpsServer[i], p.httpsListen[i]) + for _, l := range p.httpsListen { + go func(l net.Listener) { _ = p.httpsServer.Serve(l) }(l) + } + + for _, l := range p.h3Listen { + go func(l quic.EarlyListener) { _ = p.h3Server.ServeListener(l) }(l) + } for _, l := range p.quicListen { diff --git a/proxy/server_dnscrypt.go b/proxy/server_dnscrypt.go index 308d068ab..afc7b68f8 100644 --- a/proxy/server_dnscrypt.go +++ b/proxy/server_dnscrypt.go @@ -4,12 +4,33 @@ import ( "fmt" "net" + "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" "github.com/ameshkov/dnscrypt/v2" "github.com/miekg/dns" ) -func (p *Proxy) createDNSCryptListeners() error { +func (p *Proxy) createDNSCryptListeners() (err error) { + if len(p.DNSCryptUDPListenAddr) == 0 && len(p.DNSCryptTCPListenAddr) == 0 { + // Do nothing if DNSCrypt listen addresses are not specified. + return nil + } + + if p.DNSCryptResolverCert == nil || p.DNSCryptProviderName == "" { + return errors.Error("invalid DNSCrypt configuration: no certificate or provider name") + } + + log.Info("Initializing DNSCrypt: %s", p.DNSCryptProviderName) + p.dnsCryptServer = &dnscrypt.Server{ + ProviderName: p.DNSCryptProviderName, + ResolverCert: p.DNSCryptResolverCert, + Handler: &dnsCryptHandler{ + proxy: p, + + requestGoroutinesSema: p.requestGoroutinesSema, + }, + } + for _, a := range p.DNSCryptUDPListenAddr { log.Info("Creating a DNSCrypt UDP listener") udpListen, err := net.ListenUDP("udp", a) diff --git a/proxy/server_https.go b/proxy/server_https.go index 7324fc14e..0acad3e26 100644 --- a/proxy/server_https.go +++ b/proxy/server_https.go @@ -1,6 +1,7 @@ package proxy import ( + "crypto/tls" "encoding/base64" "fmt" "io" @@ -11,56 +12,107 @@ import ( "github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/netutil" + "github.com/lucas-clemente/quic-go" + "github.com/lucas-clemente/quic-go/http3" "github.com/miekg/dns" "golang.org/x/net/http2" ) -func (p *Proxy) createHTTPSListeners() error { - for _, a := range p.HTTPSListenAddr { - log.Info("Creating an HTTPS server") - tcpListen, err := net.ListenTCP("tcp", a) - if err != nil { - return fmt.Errorf("starting https listener: %w", err) - } - p.httpsListen = append(p.httpsListen, tcpListen) - log.Info("Listening to https://%s", tcpListen.Addr()) +// listenHTTP creates instances of TLS listeners that will be used to run an +// H1/H2 server. Returns the address the listener actually listens to (useful +// when port 0 is specified). +func (p *Proxy) listenHTTP(addr *net.TCPAddr) (laddr *net.TCPAddr, err error) { + tcpListen, err := net.ListenTCP("tcp", addr) + if err != nil { + return nil, fmt.Errorf("tcp listener: %w", err) + } + log.Info("Listening to https://%s", tcpListen.Addr()) - tlsConfig := p.TLSConfig.Clone() - tlsConfig.NextProtos = []string{http2.NextProtoTLS, "http/1.1"} + tlsConfig := p.TLSConfig.Clone() + tlsConfig.NextProtos = []string{http2.NextProtoTLS, "http/1.1"} - srv := &http.Server{ - TLSConfig: tlsConfig, - Handler: p, - ReadHeaderTimeout: defaultTimeout, - WriteTimeout: defaultTimeout, - } + tlsListen := tls.NewListener(tcpListen, tlsConfig) + p.httpsListen = append(p.httpsListen, tlsListen) - p.httpsServer = append(p.httpsServer, srv) + return tcpListen.Addr().(*net.TCPAddr), nil +} + +// listenH3 creates instances of QUIC listeners that will be used for running +// an HTTP/3 server.
+func (p *Proxy) listenH3(addr *net.UDPAddr) (err error) { + tlsConfig := p.TLSConfig.Clone() + tlsConfig.NextProtos = []string{"h3"} + quicListen, err := quic.ListenAddrEarly(addr.String(), tlsConfig, newServerQUICConfig()) + if err != nil { + return fmt.Errorf("quic listener: %w", err) } + log.Info("Listening to h3://%s", quicListen.Addr()) + + p.h3Listen = append(p.h3Listen, quicListen) return nil } -// serveHttps starts the HTTPS server -func (p *Proxy) listenHTTPS(srv *http.Server, l net.Listener) { - log.Info("Listening to DNS-over-HTTPS on %s", l.Addr()) - err := srv.ServeTLS(l, "", "") +// createHTTPSListeners creates TCP/UDP listeners and HTTP/H3 servers. +func (p *Proxy) createHTTPSListeners() (err error) { + p.httpsServer = &http.Server{ + Handler: &proxyHTTPHandler{ + proxy: p, + h3: false, + }, + ReadHeaderTimeout: defaultTimeout, + WriteTimeout: defaultTimeout, + } + + if p.HTTP3 { + p.h3Server = &http3.Server{ + Handler: &proxyHTTPHandler{ + proxy: p, + h3: true, + }, + } + } + + for _, addr := range p.HTTPSListenAddr { + log.Info("Creating an HTTPS server") - if err != http.ErrServerClosed { - log.Info("HTTPS server was closed unexpectedly: %s", err) - } else { - log.Info("HTTPS server was closed") + tcpAddr, err := p.listenHTTP(addr) + if err != nil { + return fmt.Errorf("failed to start HTTPS server on %s: %w", addr, err) + } + + if p.HTTP3 { + // The HTTP/3 server listens on the same IP:port pair as the HTTP/2 + // server, but over UDP. + udpAddr := &net.UDPAddr{IP: tcpAddr.IP, Port: tcpAddr.Port} + err = p.listenH3(udpAddr) + if err != nil { + return fmt.Errorf("failed to start HTTP/3 server on %s: %w", udpAddr, err) + } + } } + + return nil +} + +// proxyHTTPHandler implements http.Handler and processes DoH queries. +type proxyHTTPHandler struct { + // h3 is true if this is an HTTP/3 request handler. + h3 bool + proxy *Proxy } -// ServeHTTP is the http.RequestHandler implementation that handles DoH queries +// type check +var _ http.Handler = &proxyHTTPHandler{} + +// ServeHTTP is the http.Handler implementation that handles DoH queries. // Here is what it returns: // // - http.StatusBadRequest if there is no DNS request data; // - http.StatusUnsupportedMediaType if request content type is not // "application/dns-message"; // - http.StatusMethodNotAllowed if request method is not GET or POST.
-func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { +func (h *proxyHTTPHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { log.Tracef("Incoming HTTPS request on %s", r.URL) var buf []byte @@ -103,12 +155,12 @@ func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } - addr, prx, err := remoteAddr(r) + addr, prx, err := remoteAddr(r, h.h3) if err != nil { log.Debug("warning: getting real ip: %s", err) } - d := p.newDNSContext(ProtoHTTPS, req) + d := h.proxy.newDNSContext(ProtoHTTPS, req) d.Addr = addr d.HTTPRequest = r d.HTTPResponseWriter = w @@ -116,13 +168,13 @@ func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { if prx != nil { ip, _ := netutil.IPAndPortFromAddr(prx) log.Debug("request came from proxy server %s", prx) - if !p.proxyVerifier.Contains(ip) { + if !h.proxy.proxyVerifier.Contains(ip) { log.Debug("proxy %s is not trusted, using original remote addr", ip) d.Addr = prx } } - err = p.handleDNSRequest(d) + err = h.proxy.handleDNSRequest(d) if err != nil { log.Tracef("error handling DNS (%s) request: %s", d.Proto, err) } @@ -187,7 +239,7 @@ func realIPFromHdrs(r *http.Request) (realIP net.IP) { // remoteAddr returns the real client's address and the IP address of the latest // proxy server if any. -func remoteAddr(r *http.Request) (addr, prx net.Addr, err error) { +func remoteAddr(r *http.Request, h3 bool) (addr, prx net.Addr, err error) { var hostStr, portStr string if hostStr, portStr, err = net.SplitHostPort(r.RemoteAddr); err != nil { return nil, nil, err @@ -206,13 +258,15 @@ func remoteAddr(r *http.Request) (addr, prx net.Addr, err error) { if realIP := realIPFromHdrs(r); realIP != nil { log.Tracef("Using IP address from HTTP request: %s", realIP) - // TODO(a.garipov): Use net.UDPAddr here and below when - // necessary when we start supporting HTTP/3. - // // TODO(a.garipov): Add port if we can get it from headers like // X-Real-Port, X-Forwarded-Port, etc. - addr = &net.TCPAddr{IP: realIP, Port: 0} - prx = &net.TCPAddr{IP: host, Port: port} + if h3 { + addr = &net.UDPAddr{IP: realIP, Port: 0} + prx = &net.UDPAddr{IP: host, Port: port} + } else { + addr = &net.TCPAddr{IP: realIP, Port: 0} + prx = &net.TCPAddr{IP: host, Port: port} + } return addr, prx, nil } diff --git a/proxy/server_https_test.go b/proxy/server_https_test.go index 5685f1888..68d14289b 100644 --- a/proxy/server_https_test.go +++ b/proxy/server_https_test.go @@ -5,6 +5,7 @@ import ( "context" "crypto/tls" "crypto/x509" + "fmt" "io" "net" "net/http" @@ -12,12 +13,53 @@ import ( "testing" "github.com/AdguardTeam/golibs/netutil" + "github.com/AdguardTeam/golibs/testutil" + "github.com/lucas-clemente/quic-go" + "github.com/lucas-clemente/quic-go/http3" "github.com/miekg/dns" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestHttpsProxy(t *testing.T) { + testCases := []struct { + name string + http3 bool + }{{ + name: "https_proxy", + http3: false, + }, { + name: "h3_proxy", + http3: true, + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Prepare dnsProxy with its configuration. + tlsConf, caPem := createServerTLSConfig(t) + dnsProxy := createTestProxy(t, tlsConf) + dnsProxy.HTTP3 = tc.http3 + + // Run the proxy. + err := dnsProxy.Start() + require.NoError(t, err) + testutil.CleanupAndRequireSuccess(t, dnsProxy.Stop) + + // Create the HTTP client that we'll be using for this test. 
+ client := createTestHTTPClient(dnsProxy, caPem, tc.http3) + + // Prepare a test message to be sent to the server. + msg := createTestMessage() + + // Send the test message and check if the response is what we + // expected. + resp := sendTestDoHMessage(t, client, msg, nil) + requireResponse(t, msg, resp) + }) + } +} + +func TestHttpsProxyTrustedProxies(t *testing.T) { // Prepare the proxy server. tlsConf, caPem := createServerTLSConfig(t) dnsProxy := createTestProxy(t, tlsConf) @@ -29,30 +71,7 @@ func TestHttpsProxy(t *testing.T) { return dnsProxy.Resolve(d) } - roots := x509.NewCertPool() - ok := roots.AppendCertsFromPEM(caPem) - require.True(t, ok) - - dialer := &net.Dialer{ - Timeout: defaultTimeout, - } - dialContext := func(ctx context.Context, network, addr string) (net.Conn, error) { - // Route request to the DNS-over-HTTPS server address. - return dialer.DialContext(ctx, network, dnsProxy.Addr(ProtoHTTPS).String()) - } - - transport := &http.Transport{ - TLSClientConfig: &tls.Config{ - ServerName: tlsServerName, - RootCAs: roots, - }, - DisableCompression: true, - DialContext: dialContext, - } - client := http.Client{ - Transport: transport, - Timeout: defaultTimeout, - } + client := createTestHTTPClient(dnsProxy, caPem, false) clientIP, proxyIP := net.IP{1, 2, 3, 4}, net.IP{127, 0, 0, 1} msg := createTestMessage() @@ -63,42 +82,14 @@ func TestHttpsProxy(t *testing.T) { // Start listening. serr := dnsProxy.Start() require.NoError(t, serr) - t.Cleanup(func() { - derr := dnsProxy.Stop() - require.NoError(t, derr) - }) + testutil.CleanupAndRequireSuccess(t, dnsProxy.Stop) - packed, err := msg.Pack() - require.NoError(t, err) - - b := bytes.NewBuffer(packed) - req, err := http.NewRequest("POST", "https://test.com", b) - require.NoError(t, err) - - req.Header.Set("Content-Type", "application/dns-message") - req.Header.Set("Accept", "application/dns-message") - // IP "1.2.3.4" will be used as a client address in DNSContext. - req.Header.Set("X-Forwarded-For", strings.Join( - []string{clientIP.String(), proxyIP.String()}, - ",", - )) - - resp, err := client.Do(req) - require.NoError(t, err) - - if resp != nil && resp.Body != nil { - t.Cleanup(func() { - resp.Body.Close() - }) + hdrs := map[string]string{ + "X-Forwarded-For": strings.Join([]string{clientIP.String(), proxyIP.String()}, ","), } - body, err := io.ReadAll(resp.Body) - require.NoError(t, err) - - reply := &dns.Msg{} - err = reply.Unpack(body) - require.NoError(t, err) - requireResponse(t, msg, reply) + resp := sendTestDoHMessage(t, client, msg, hdrs) + requireResponse(t, msg, resp) } t.Run("success", func(t *testing.T) { @@ -300,7 +291,7 @@ func TestRemoteAddr(t *testing.T) { } t.Run(tc.name, func(t *testing.T) { - addr, prx, err := remoteAddr(r) + addr, prx, err := remoteAddr(r, false) if tc.wantErr != "" { assert.Equal(t, tc.wantErr, err.Error()) @@ -317,3 +308,90 @@ func TestRemoteAddr(t *testing.T) { }) } } + +// sendTestDoHMessage sends the specified DNS message using client and returns +// the DNS response. 
+func sendTestDoHMessage( + t *testing.T, + client *http.Client, + m *dns.Msg, + hdrs map[string]string, +) (resp *dns.Msg) { + packed, err := m.Pack() + require.NoError(t, err) + + b := bytes.NewBuffer(packed) + u := fmt.Sprintf("https://%s/dns-query", tlsServerName) + req, err := http.NewRequest(http.MethodPost, u, b) + require.NoError(t, err) + + req.Header.Set("Content-Type", "application/dns-message") + req.Header.Set("Accept", "application/dns-message") + + for k, v := range hdrs { + req.Header.Set(k, v) + } + + httpResp, err := client.Do(req) + require.NoError(t, err) + testutil.CleanupAndRequireSuccess(t, httpResp.Body.Close) + + body, err := io.ReadAll(httpResp.Body) + require.NoError(t, err) + + resp = &dns.Msg{} + err = resp.Unpack(body) + require.NoError(t, err) + + return resp +} + +// createTestHTTPClient creates an *http.Client that will be used to send +// requests to the specified dnsProxy. +func createTestHTTPClient(dnsProxy *Proxy, caPem []byte, http3Enabled bool) (client *http.Client) { + // prepare roots list so that the server cert was successfully validated. + roots := x509.NewCertPool() + roots.AppendCertsFromPEM(caPem) + tlsClientConfig := &tls.Config{ + ServerName: tlsServerName, + RootCAs: roots, + } + + var transport http.RoundTripper + + if http3Enabled { + transport = &http3.RoundTripper{ + Dial: func( + ctx context.Context, + _ string, + tlsCfg *tls.Config, + cfg *quic.Config, + ) (quic.EarlyConnection, error) { + addr := dnsProxy.Addr(ProtoHTTPS).String() + return quic.DialAddrEarlyContext(ctx, addr, tlsCfg, cfg) + }, + TLSClientConfig: tlsClientConfig, + QuicConfig: &quic.Config{}, + DisableCompression: true, + } + } else { + dialer := &net.Dialer{ + Timeout: defaultTimeout, + } + dialContext := func(ctx context.Context, network, addr string) (net.Conn, error) { + // Route request to the DNS-over-HTTPS server address. + return dialer.DialContext(ctx, network, dnsProxy.Addr(ProtoHTTPS).String()) + } + + transport = &http.Transport{ + TLSClientConfig: tlsClientConfig, + DisableCompression: true, + DialContext: dialContext, + } + } + + return &http.Client{ + Transport: transport, + Timeout: defaultTimeout, + } +} diff --git a/proxy/server_quic.go b/proxy/server_quic.go index 59958f3e0..ce8adc238 100644 --- a/proxy/server_quic.go +++ b/proxy/server_quic.go @@ -5,12 +5,14 @@ import ( "encoding/binary" "errors" "fmt" - "strings" + "io" + "math" + "net" "time" "github.com/AdguardTeam/dnsproxy/proxyutil" "github.com/AdguardTeam/golibs/log" - "github.com/AdguardTeam/golibs/stringutil" + "github.com/bluele/gcache" "github.com/lucas-clemente/quic-go" "github.com/miekg/dns" ) @@ -31,6 +33,18 @@ var compatProtoDQ = []string{NextProtoDQ, "doq-i02", "doq-i00", "dq"} // better for clients written with ngtcp2. const maxQUICIdleTimeout = 5 * time.Minute +// quicAddrValidatorCacheSize is the size of the cache that we use in the QUIC +// address validator. The value is chosen arbitrarily and we should consider +// making it configurable. +// TODO(ameshkov): make it configurable. +const quicAddrValidatorCacheSize = 1000 + +// quicAddrValidatorCacheTTL is time-to-live for cache items in the QUIC address +// validator. The value is chosen arbitrarily and we should consider making it +// configurable. +// TODO(ameshkov): make it configurable. +const quicAddrValidatorCacheTTL = 30 * time.Minute + const ( // DoQCodeNoError is used when the connection or stream needs to be closed, // but there is no error to signal. 
@@ -44,14 +58,19 @@ const ( DoQCodeProtocolError quic.ApplicationErrorCode = 2 ) +// createQUICListeners creates QUIC listeners for the DoQ server. func (p *Proxy) createQUICListeners() error { for _, a := range p.QUICListenAddr { log.Info("Creating a QUIC listener") tlsConfig := p.TLSConfig.Clone() tlsConfig.NextProtos = compatProtoDQ - quicListen, err := quic.ListenAddr(a.String(), tlsConfig, &quic.Config{MaxIdleTimeout: maxQUICIdleTimeout}) + quicListen, err := quic.ListenAddrEarly( + a.String(), + tlsConfig, + newServerQUICConfig(), + ) if err != nil { - return fmt.Errorf("starting quic listener: %w", err) + return fmt.Errorf("quic listener: %w", err) } p.quicListen = append(p.quicListen, quicListen) @@ -63,13 +82,14 @@ // quicPacketLoop listens for incoming QUIC packets. // // See also the comment on Proxy.requestGoroutinesSema. -func (p *Proxy) quicPacketLoop(l quic.Listener, requestGoroutinesSema semaphore) { +func (p *Proxy) quicPacketLoop(l quic.EarlyListener, requestGoroutinesSema semaphore) { log.Info("Entering the DNS-over-QUIC listener loop on %s", l.Addr()) for { conn, err := l.Accept(context.Background()) + if err != nil { if isQUICNonCrit(err) { - log.Tracef("quic connection closed or timeout: %s", err) + log.Tracef("quic connection closed or timed out: %s", err) } else { log.Error("reading from quic listen: %s", err) } @@ -140,7 +160,10 @@ func (p *Proxy) handleQUICStream(stream quic.Stream, conn quic.Connection) { buf := *bufPtr n, err := stream.Read(buf) - if n < minDNSPacketSize { + // Note that io.EOF does not really mean that there's any error; it is + // just a signal that there will be no data to read anymore from this + // stream. + if (err != nil && err != io.EOF) || n < minDNSPacketSize { logShortQUICRead(err) return @@ -295,20 +318,42 @@ func logShortQUICRead(err error) { } // isQUICNonCrit returns true if err is a non-critical error, most probably -// a timeout or a closed connection. -// -// TODO(a.garipov): Inspect and rewrite with modern error handling. +// related to the current QUIC implementation. +// TODO(ameshkov): re-test when updating quic-go. func isQUICNonCrit(err error) (ok bool) { if err == nil { return false } - errStr := err.Error() + if errors.Is(err, quic.ErrServerClosed) { + // This error is returned when the QUIC listener was closed by us. This + // is an expected error; we don't need the detailed logs here. + return true + } + + var qAppErr *quic.ApplicationError + if errors.As(err, &qAppErr) && qAppErr.ErrorCode == 0 { + // This error is returned when a QUIC connection was gracefully closed. + // No need to have detailed logs for it either. + return true + } - return strings.Contains(errStr, "server closed") || - stringutil.ContainsFold(errStr, "no recent network activity") || - strings.HasSuffix(errStr, "Application error 0x0") || - errStr == "EOF" + if errors.Is(err, quic.Err0RTTRejected) { + // This error is returned on AcceptStream calls when the server rejects + // 0-RTT for some reason. This is a common scenario, no need for extra + // logs. + return true + } + + var qIdleErr *quic.IdleTimeoutError + if errors.As(err, &qIdleErr) { + // This error is returned when we're trying to accept a new stream from + // a connection that had no activity for longer than the keep-alive + // timeout. This is a common scenario, no need for extra logs. + return true + } + + return false } // closeQUICConn quietly closes the QUIC connection.
@@ -318,3 +363,51 @@ func closeQUICConn(conn quic.Connection, code quic.ApplicationErrorCode) { log.Debug("failed to close QUIC connection: %v", err) } } + +// newServerQUICConfig creates *quic.Config populated with the default settings. +// This function is supposed to be used for both DoQ and DoH3 servers. +func newServerQUICConfig() (conf *quic.Config) { + v := newQUICAddrValidator(quicAddrValidatorCacheSize, quicAddrValidatorCacheTTL) + + return &quic.Config{ + MaxIdleTimeout: maxQUICIdleTimeout, + RequireAddressValidation: v.requiresValidation, + MaxIncomingStreams: math.MaxUint16, + MaxIncomingUniStreams: math.MaxUint16, + } +} + +// quicAddrValidator is a helper struct that holds a small LRU cache of +// addresses for which we do not require address validation. +type quicAddrValidator struct { + cache gcache.Cache + ttl time.Duration +} + +// newQUICAddrValidator initializes a new instance of *quicAddrValidator. +func newQUICAddrValidator(cacheSize int, ttl time.Duration) (v *quicAddrValidator) { + return &quicAddrValidator{ + cache: gcache.New(cacheSize).LRU().Build(), + ttl: ttl, + } +} + +// requiresValidation determines if a QUIC Retry packet should be sent to the +// client. This allows the server to verify the client's address but increases +// the latency. +func (v *quicAddrValidator) requiresValidation(addr net.Addr) (ok bool) { + key := addr.String() + if v.cache.Has(key) { + return false + } + + err := v.cache.SetWithExpire(key, true, v.ttl) + if err != nil { + // Shouldn't happen, since we don't set a serialization function. + panic(fmt.Errorf("quic validator: setting cache item: %w", err)) + } + + // Address not found in the cache, so return true to make sure the server + // will require address validation. + return true +} diff --git a/proxy/server_quic_test.go b/proxy/server_quic_test.go index 3896dad0f..5d93f84ea 100644 --- a/proxy/server_quic_test.go +++ b/proxy/server_quic_test.go @@ -34,7 +34,7 @@ func TestQuicProxy(t *testing.T) { addr := dnsProxy.Addr(ProtoQUIC) // Open QUIC connection. - conn, err := quic.DialAddr(addr.String(), tlsConfig, nil) + conn, err := quic.DialAddrEarly(addr.String(), tlsConfig, nil) require.NoError(t, err) defer conn.CloseWithError(DoQCodeNoError, "") diff --git a/upstream/bootstrap.go b/upstream/bootstrap.go index a5e70edb5..4568732d6 100755 --- a/upstream/bootstrap.go +++ b/upstream/bootstrap.go @@ -306,3 +306,16 @@ func (n *bootstrapper) createDialContext(addresses []string) (dialContext dialHa return nil, errors.List("all dialers failed", errs...) } } + +// newContext creates a new context with a deadline if needed. If no timeout is +// set, cancel is a simple no-op.
+func (n *bootstrapper) newContext() (ctx context.Context, cancel context.CancelFunc) { + ctx = context.Background() + cancel = func() {} + + if n.options.Timeout > 0 { + ctx, cancel = context.WithDeadline(ctx, time.Now().Add(n.options.Timeout)) + } + + return ctx, cancel +} diff --git a/upstream/upstream_doh.go b/upstream/upstream_doh.go index 78a93afa9..b8868aa67 100644 --- a/upstream/upstream_doh.go +++ b/upstream/upstream_doh.go @@ -9,7 +9,6 @@ import ( "net" "net/http" "net/url" - "os" "sync" "time" @@ -74,12 +73,7 @@ func newDoH(uu *url.URL, opts *Options) (u Upstream, err error) { quicConfig: &quic.Config{ KeepAlivePeriod: QUICKeepAlivePeriod, - // You can read more on address validation here: - // https://datatracker.ietf.org/doc/html/rfc9000#section-8.1 - // Setting maxOrigins to 1 and tokensPerOrigin to 10 assuming that - // this is more than enough for the way we use it (one connection - // per upstream). - TokenStore: quic.NewLRUTokenStore(1, 10), + TokenStore: newQUICTokenStore(), }, }, nil } @@ -88,17 +82,69 @@ func newDoH(uu *url.URL, opts *Options) (u Upstream, err error) { func (p *dnsOverHTTPS) Address() string { return p.boot.URL.String() } // Exchange implements the Upstream interface for *dnsOverHTTPS. -func (p *dnsOverHTTPS) Exchange(m *dns.Msg) (*dns.Msg, error) { +func (p *dnsOverHTTPS) Exchange(m *dns.Msg) (resp *dns.Msg, err error) { + // Quote from https://www.rfc-editor.org/rfc/rfc8484.html: + // In order to maximize HTTP cache friendliness, DoH clients using media + // formats that include the ID field from the DNS message header, such + // as "application/dns-message", SHOULD use a DNS ID of 0 in every DNS + // request. + id := m.Id + m.Id = 0 + defer func() { + // Restore the original ID to not break compatibility with proxies. + m.Id = id + if resp != nil { + resp.Id = id + } + }() + + // Check if there was already an active client before sending the request. + // We'll only attempt to re-connect if there was one. + hasClient := p.hasClient() + + // Make the first attempt to send the DNS query. + resp, err = p.exchangeHTTPS(m) + + // Make up to 2 attempts to re-create the HTTP client and send the request + // again. There are several cases (mostly, with QUIC) where this workaround + // is necessary to make HTTP client usable. We need to make 2 attempts in + // the case when the connection was closed (due to inactivity for example) + // AND the server refuses to open a 0-RTT connection. + for i := 0; hasClient && p.shouldRetry(err) && i < 2; i++ { + log.Debug("re-creating the HTTP client and retrying due to %v", err) + + p.clientGuard.Lock() + p.client = nil + // Re-create the token store to make sure we're not trying to use invalid + // tokens for 0-RTT. + p.quicConfig.TokenStore = newQUICTokenStore() + p.clientGuard.Unlock() + + resp, err = p.exchangeHTTPS(m) + } + + if err != nil { + // If the request failed anyway, make sure we don't use this client. + p.clientGuard.Lock() + p.client = nil + p.clientGuard.Unlock() + } + + return resp, err +} + +// exchangeHTTPS creates an HTTP client and sends the DNS query using it. 
+func (p *dnsOverHTTPS) exchangeHTTPS(m *dns.Msg) (resp *dns.Msg, err error) { client, err := p.getClient() if err != nil { return nil, fmt.Errorf("initializing http client: %w", err) } logBegin(p.Address(), m) - r, err := p.exchangeHTTPSClient(m, client) + resp, err = p.exchangeHTTPSClient(m, client) logFinish(p.Address(), err) - return r, err + return resp, err } // exchangeHTTPSClient sends the DNS query to a DoH resolver using the specified @@ -125,16 +171,6 @@ func (p *dnsOverHTTPS) exchangeHTTPSClient(m *dns.Msg, client *http.Client) (*dn defer resp.Body.Close() } if err != nil { - if errors.Is(err, os.ErrDeadlineExceeded) { - // If this is a timeout error, trying to forcibly re-create the HTTP - // client instance. - // - // See https://github.com/AdguardTeam/AdGuardHome/issues/3217. - p.clientGuard.Lock() - p.client = nil - p.clientGuard.Unlock() - } - return nil, fmt.Errorf("requesting %s: %w", p.boot.URL, err) } @@ -160,6 +196,38 @@ func (p *dnsOverHTTPS) exchangeHTTPSClient(m *dns.Msg, client *http.Client) (*dn return &response, err } +// hasClient returns true if this connection already has an active HTTP client. +func (p *dnsOverHTTPS) hasClient() (ok bool) { + p.clientGuard.Lock() + defer p.clientGuard.Unlock() + + return p.client != nil +} + +// shouldRetry checks what error we have received and returns true if we should +// re-create the HTTP client and retry the request. +func (p *dnsOverHTTPS) shouldRetry(err error) (ok bool) { + if err == nil { + return false + } + + var netErr net.Error + if errors.As(err, &netErr) && netErr.Timeout() { + // If this is a timeout error, trying to forcibly re-create the HTTP + // client instance. This is an attempt to fix an issue with DoH client + // stalling after a network change. + // + // See https://github.com/AdguardTeam/AdGuardHome/issues/3217. + return true + } + + if isQUICRetryError(err) { + return true + } + + return false +} + // getClient gets or lazily initializes an HTTP client (and transport) that will // be used for this DoH resolver. func (p *dnsOverHTTPS) getClient() (c *http.Client, err error) { @@ -266,7 +334,7 @@ func (p *dnsOverHTTPS) createTransport() (t http.RoundTripper, err error) { func (p *dnsOverHTTPS) createTransportH3( tlsConfig *tls.Config, dialContext dialHandler, -) (roundTripper *http3.RoundTripper, err error) { +) (roundTripper http.RoundTripper, err error) { if !p.supportsH3() { return nil, errors.Error("HTTP3 support is not enabled") } @@ -276,21 +344,25 @@ func (p *dnsOverHTTPS) createTransportH3( return nil, err } - return &http3.RoundTripper{ + rt := &http3.RoundTripper{ Dial: func( ctx context.Context, + // Ignore the address and always connect to the one that we got // from the bootstrapper. 
_ string, tlsCfg *tls.Config, cfg *quic.Config, ) (c quic.EarlyConnection, err error) { - return quic.DialAddrEarlyContext(ctx, addr, tlsCfg, cfg) + c, err = quic.DialAddrEarlyContext(ctx, addr, tlsCfg, cfg) + return c, err }, DisableCompression: true, TLSClientConfig: tlsConfig, QuicConfig: p.quicConfig, - }, nil + } + + return rt, nil } // probeH3 runs a test to check whether QUIC is faster than TLS for this diff --git a/upstream/upstream_doh_test.go b/upstream/upstream_doh_test.go index 16a9a09f4..2e0075e3e 100644 --- a/upstream/upstream_doh_test.go +++ b/upstream/upstream_doh_test.go @@ -7,6 +7,7 @@ import ( "fmt" "net" "net/http" + "strconv" "testing" "time" @@ -105,15 +106,88 @@ func TestUpstreamDoH(t *testing.T) { } } +func TestUpstreamDoH_serverRestart(t *testing.T) { + testCases := []struct { + name string + httpVersions []HTTPVersion + }{ + { + name: "http2", + httpVersions: []HTTPVersion{HTTPVersion11, HTTPVersion2}, + }, + { + name: "http3", + httpVersions: []HTTPVersion{HTTPVersion3}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Run the first server instance. + srv := startDoHServer(t, testDoHServerOptions{ + http3Enabled: true, + }) + + // Create a DNS-over-HTTPS upstream. + address := fmt.Sprintf("https://%s/dns-query", srv.addr) + u, err := AddressToUpstream( + address, + &Options{ + InsecureSkipVerify: true, + HTTPVersions: tc.httpVersions, + Timeout: time.Second, + }, + ) + require.NoError(t, err) + + // Test that the upstream works properly. + checkUpstream(t, u, address) + + // Now let's restart the server on the same address. + _, portStr, err := net.SplitHostPort(srv.addr) + require.NoError(t, err) + port, err := strconv.Atoi(portStr) + + // Shutdown the first server. + srv.Shutdown() + + // Start the new one on the same port. + srv = startDoHServer(t, testDoHServerOptions{ + http3Enabled: true, + port: port, + }) + + // Check that everything works after restart. + checkUpstream(t, u, address) + + // Stop the server again. + srv.Shutdown() + + // Now try to send a message and make sure that it returns an error. + _, err = u.Exchange(createTestMessage()) + require.Error(t, err) + + // Start the server one more time. + srv = startDoHServer(t, testDoHServerOptions{ + http3Enabled: true, + port: port, + }) + + // Check that everything works after the second restart. + checkUpstream(t, u, address) + }) + } +} + // testDoHServerOptions allows customizing testDoHServer behavior. type testDoHServerOptions struct { http3Enabled bool delayHandshakeH2 time.Duration delayHandshakeH3 time.Duration + port int } -// testDoHServer is an instance of a test DNS-over-HTTPS server that we use -// for tests. +// testDoHServer is an instance of a test DNS-over-HTTPS server. type testDoHServer struct { // addr is the address that this server listens to. addr string @@ -126,9 +200,12 @@ type testDoHServer struct { // serverH3 is an HTTP/3 server. serverH3 *http3.Server + + // listenerH3 that's used to serve HTTP/3. + listenerH3 quic.EarlyListener } -// Shutdown stops the DOH server. +// Shutdown stops the DoH server. func (s *testDoHServer) Shutdown() { if s.server != nil { _ = s.server.Shutdown(context.Background()) @@ -136,6 +213,7 @@ func (s *testDoHServer) Shutdown() { if s.serverH3 != nil { _ = s.serverH3.Close() + _ = s.listenerH3.Close() } } @@ -156,7 +234,8 @@ func startDoHServer( } // Listen TCP first. 
- tcpAddr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:0") + listenAddr := fmt.Sprintf("127.0.0.1:%d", opts.port) + tcpAddr, err := net.ResolveTCPAddr("tcp", listenAddr) require.NoError(t, err) tcpListen, err := net.ListenTCP("tcp", tcpAddr) @@ -179,6 +258,7 @@ func startDoHServer( tcpAddr = tcpListen.Addr().(*net.TCPAddr) var serverH3 *http3.Server + var listenerH3 quic.EarlyListener if opts.http3Enabled { tlsConfigH3 := tlsConfig.Clone() @@ -191,9 +271,7 @@ func startDoHServer( } serverH3 = &http3.Server{ - TLSConfig: tlsConfig.Clone(), - QuicConfig: &quic.Config{}, - Handler: handler, + Handler: handler, } // Listen UDP for the H3 server. Reuse the same port as was used for the @@ -201,17 +279,18 @@ func startDoHServer( udpAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("127.0.0.1:%d", tcpAddr.Port)) require.NoError(t, err) - udpListen, err := net.ListenUDP("udp", udpAddr) + listenerH3, err = quic.ListenAddrEarly(udpAddr.String(), tlsConfigH3, &quic.Config{}) require.NoError(t, err) // Run the H3 server. - go serverH3.Serve(udpListen) + go serverH3.ServeListener(listenerH3) } return &testDoHServer{ - tlsConfig: tlsConfig, - server: server, - serverH3: serverH3, + tlsConfig: tlsConfig, + server: server, + serverH3: serverH3, + listenerH3: listenerH3, // Save the address that the server listens to. addr: tcpAddr.String(), } diff --git a/upstream/upstream_quic.go b/upstream/upstream_quic.go index eb98fd171..f8942def7 100644 --- a/upstream/upstream_quic.go +++ b/upstream/upstream_quic.go @@ -9,6 +9,7 @@ import ( "time" "github.com/AdguardTeam/dnsproxy/proxyutil" + "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" "github.com/lucas-clemente/quic-go" "github.com/miekg/dns" @@ -22,9 +23,6 @@ const ( // an internal error and is incapable of pursuing the transaction or the // connection. QUICCodeInternalError = quic.ApplicationErrorCode(1) - // QUICCodeProtocolError signals that the DoQ implementation encountered - // a protocol error and is forcibly aborting the connection. - QUICCodeProtocolError = quic.ApplicationErrorCode(2) // QUICKeepAlivePeriod is the value that we pass to *quic.Config and that // controls the period with with keep-alive frames are being sent to the // connection. We set it to 20s as it would be in the quic-go@v0.27.1 with @@ -49,11 +47,12 @@ type dnsOverQUIC struct { // conn is the current active QUIC connection. It can be closed and // re-opened when needed. conn quic.Connection + // connGuard protects conn and quicConfig. + connGuard sync.RWMutex // bytesPool is a *sync.Pool we use to store byte buffers in. These byte // buffers are used to read responses from the upstream. - bytesPool *sync.Pool - // sync.RWMutex protects conn and bytesPool. - sync.RWMutex + bytesPool *sync.Pool + bytesPoolGuard sync.Mutex } // type check @@ -73,12 +72,7 @@ func newDoQ(uu *url.URL, opts *Options) (u Upstream, err error) { boot: b, quicConfig: &quic.Config{ KeepAlivePeriod: QUICKeepAlivePeriod, - // You can read more on address validation here: - // https://datatracker.ietf.org/doc/html/rfc9000#section-8.1 - // Setting maxOrigins to 1 and tokensPerOrigin to 10 assuming that - // this is more than enough for the way we use it (one connection - // per upstream). 
- TokenStore: quic.NewLRUTokenStore(1, 10), + TokenStore: newQUICTokenStore(), }, }, nil } @@ -87,25 +81,59 @@ func newDoQ(uu *url.URL, opts *Options) (u Upstream, err error) { func (p *dnsOverQUIC) Address() string { return p.boot.URL.String() } // Exchange implements the Upstream interface for *dnsOverQUIC. -func (p *dnsOverQUIC) Exchange(m *dns.Msg) (res *dns.Msg, err error) { - var conn quic.Connection - conn, err = p.getConnection(true) - if err != nil { - return nil, err - } - +func (p *dnsOverQUIC) Exchange(m *dns.Msg) (resp *dns.Msg, err error) { // When sending queries over a QUIC connection, the DNS Message ID MUST be // set to zero. id := m.Id m.Id = 0 defer func() { - // Restore the original ID to not break compatibility with proxies + // Restore the original ID to not break compatibility with proxies. m.Id = id - if res != nil { - res.Id = id + if resp != nil { + resp.Id = id } }() + // Check if there was already an active conn before sending the request. + // We'll only attempt to re-connect if there was one. + hasConnection := p.hasConnection() + + // Make the first attempt to send the DNS query. + resp, err = p.exchangeQUIC(m) + + // Make up to 2 attempts to re-open the QUIC connection and send the request + // again. There are several cases where this workaround is necessary to + // make DoQ usable. We need to make 2 attempts in the case when the + // connection was closed (due to inactivity for example) AND the server + // refuses to open a 0-RTT connection. + for i := 0; hasConnection && p.shouldRetry(err) && i < 2; i++ { + log.Debug("re-creating the QUIC connection and retrying due to %v", err) + + // Close the active connection to make sure we'll try to re-connect. + p.closeConnWithError(QUICCodeNoError) + + // Retry sending the request. + resp, err = p.exchangeQUIC(m) + } + + if err != nil { + // If we're unable to exchange messages, make sure the connection is + // closed and signal about an internal error. + p.closeConnWithError(QUICCodeInternalError) + } + + return resp, err +} + +// exchangeQUIC attempts to open a QUIC connection, send the DNS message +// through it and return the response it got from the server. +func (p *dnsOverQUIC) exchangeQUIC(m *dns.Msg) (resp *dns.Msg, err error) { + var conn quic.Connection + conn, err = p.getConnection(true) + if err != nil { + return nil, err + } + var buf []byte buf, err = m.Pack() if err != nil { @@ -115,36 +143,34 @@ func (p *dnsOverQUIC) Exchange(m *dns.Msg) (res *dns.Msg, err error) { var stream quic.Stream stream, err = p.openStream(conn) if err != nil { - p.closeConnWithError(QUICCodeInternalError) - return nil, fmt.Errorf("open new stream to %s: %w", p.Address(), err) + return nil, err } _, err = stream.Write(proxyutil.AddPrefix(buf)) if err != nil { - p.closeConnWithError(QUICCodeInternalError) return nil, fmt.Errorf("failed to write to a QUIC stream: %w", err) } // The client MUST send the DNS query over the selected stream, and MUST // indicate through the STREAM FIN mechanism that no further data will - // be sent on that stream. - // stream.Close() -- closes the write-direction of the stream. + // be sent on that stream. Note, that stream.Close() closes the + // write-direction of the stream, but does not prevent reading from it. _ = stream.Close() - res, err = p.readMsg(stream) - if err != nil { - // If a peer encounters such an error condition, it is considered a - // fatal error. 
It SHOULD forcibly abort the connection using QUIC's - // CONNECTION_CLOSE mechanism and SHOULD use the DoQ error code - // DOQ_PROTOCOL_ERROR. - p.closeConnWithError(QUICCodeProtocolError) - } - return res, err + return p.readMsg(stream) +} + +// shouldRetry checks what error we received and decides whether it is required +// to re-open the connection and retry sending the request. +func (p *dnsOverQUIC) shouldRetry(err error) (ok bool) { + return isQUICRetryError(err) } // getBytesPool returns (creates if needed) a pool we store byte buffers in. func (p *dnsOverQUIC) getBytesPool() (pool *sync.Pool) { - p.Lock() + p.bytesPoolGuard.Lock() + defer p.bytesPoolGuard.Unlock() + if p.bytesPool == nil { p.bytesPool = &sync.Pool{ New: func() interface{} { @@ -154,7 +180,7 @@ func (p *dnsOverQUIC) getBytesPool() (pool *sync.Pool) { }, } } - p.Unlock() + return p.bytesPool } @@ -164,59 +190,57 @@ func (p *dnsOverQUIC) getBytesPool() (pool *sync.Pool) { // close the existing one if needed. func (p *dnsOverQUIC) getConnection(useCached bool) (quic.Connection, error) { var conn quic.Connection - p.RLock() + p.connGuard.RLock() conn = p.conn if conn != nil && useCached { - p.RUnlock() + p.connGuard.RUnlock() + return conn, nil } if conn != nil { // we're recreating the connection, let's create a new one. _ = conn.CloseWithError(QUICCodeNoError, "") } - p.RUnlock() + p.connGuard.RUnlock() - p.Lock() - defer p.Unlock() + p.connGuard.Lock() + defer p.connGuard.Unlock() var err error conn, err = p.openConnection() if err != nil { - // This does not look too nice, but QUIC (or maybe quic-go) doesn't - // seem stable enough. Maybe retransmissions aren't fully implemented - // in quic-go? Anyways, the simple solution is to make a second try when - // it fails to open the QUIC connection. - conn, err = p.openConnection() - if err != nil { - return nil, err - } + return nil, err } p.conn = conn + return conn, nil } +// hasConnection returns true if there's an active QUIC connection. +func (p *dnsOverQUIC) hasConnection() (ok bool) { + p.connGuard.Lock() + defer p.connGuard.Unlock() + + return p.conn != nil +} + // openStream opens a new QUIC stream for the specified connection. func (p *dnsOverQUIC) openStream(conn quic.Connection) (quic.Stream, error) { - ctx := context.Background() - - if p.boot.options.Timeout > 0 { - deadline := time.Now().Add(p.boot.options.Timeout) - var cancel context.CancelFunc - ctx, cancel = context.WithDeadline(context.Background(), deadline) - defer cancel() // avoid resource leak - } + ctx, cancel := p.boot.newContext() + defer cancel() stream, err := conn.OpenStreamSync(ctx) if err == nil { return stream, nil } - // try to recreate the connection. + // We can get here if the old QUIC connection is not valid anymore. We + // should try to re-create the connection again in this case. newConn, err := p.getConnection(false) if err != nil { return nil, err } - // open a new stream. + // Open a new stream. 
return newConn.OpenStreamSync(ctx) } @@ -244,7 +268,10 @@ func (p *dnsOverQUIC) openConnection() (conn quic.Connection, err error) { addr := udpConn.RemoteAddr().String() - conn, err = quic.DialAddrEarlyContext(context.Background(), addr, tlsConfig, p.quicConfig) + ctx, cancel := p.boot.newContext() + defer cancel() + + conn, err = quic.DialAddrEarlyContext(ctx, addr, tlsConfig, p.quicConfig) if err != nil { return nil, fmt.Errorf("opening quic connection to %s: %w", p.Address(), err) } @@ -256,8 +283,8 @@ func (p *dnsOverQUIC) openConnection() (conn quic.Connection, err error) { // new queries were processed in another connection. We can do that in the case // of a fatal error. func (p *dnsOverQUIC) closeConnWithError(code quic.ApplicationErrorCode) { - p.Lock() - defer p.Unlock() + p.connGuard.Lock() + defer p.connGuard.Unlock() if p.conn == nil { // Do nothing, there's no active conn anyways. @@ -269,6 +296,10 @@ func (p *dnsOverQUIC) closeConnWithError(code quic.ApplicationErrorCode) { log.Error("failed to close the conn: %v", err) } p.conn = nil + + // Re-create the token store to make sure we're not trying to use invalid + // tokens for 0-RTT. + p.quicConfig.TokenStore = newQUICTokenStore() } // readMsg reads the incoming DNS message from the QUIC stream. @@ -297,3 +328,51 @@ func (p *dnsOverQUIC) readMsg(stream quic.Stream) (m *dns.Msg, err error) { return m, nil } + +// newQUICTokenStore creates a new quic.TokenStore that is necessary to have +// in order to benefit from 0-RTT. +func newQUICTokenStore() (s quic.TokenStore) { + // You can read more on address validation here: + // https://datatracker.ietf.org/doc/html/rfc9000#section-8.1 + // Setting maxOrigins to 1 and tokensPerOrigin to 10 assuming that this is + // more than enough for the way we use it (one connection per upstream). + return quic.NewLRUTokenStore(1, 10) +} + +// isQUICRetryError checks the error and determines whether it may signal that +// we should re-create the QUIC connection. This requirement is caused by +// quic-go issues, see the comments inside this function. +// TODO(ameshkov): re-test when updating quic-go. +func isQUICRetryError(err error) (ok bool) { + var qAppErr *quic.ApplicationError + if errors.As(err, &qAppErr) && qAppErr.ErrorCode == 0 { + // This error is often returned when the server has been restarted, + // and we try to use the same connection on the client-side. It seems, + // that the old connections aren't closed immediately on the server-side + // and that's why one can run into this. + // In addition to that, quic-go HTTP3 client implementation does not + // clean up dead connections (this one is specific to DoH3 upstream): + // https://github.com/lucas-clemente/quic-go/issues/765 + return true + } + + var qIdleErr *quic.IdleTimeoutError + if errors.As(err, &qIdleErr) { + // This error means that the connection was closed due to being idle. + // In this case we should forcibly re-create the QUIC connection. + // Reproducing is rather simple, stop the server and wait for 30 seconds + // then try to send another request via the same upstream. + return true + } + + if errors.Is(err, quic.Err0RTTRejected) { + // This error happens when we try to establish a 0-RTT connection with + // a token the server is no more aware of. This can be reproduced by + // restarting the QUIC server (it will clear its tokens cache). The + // next connection attempt will return this error until the client's + // tokens cache is purged. 
+		return true
+	}
+
+	return false
+}
diff --git a/upstream/upstream_quic_test.go b/upstream/upstream_quic_test.go
index 0a2fe21f0..047d19b28 100644
--- a/upstream/upstream_quic_test.go
+++ b/upstream/upstream_quic_test.go
@@ -1,20 +1,33 @@
 package upstream
 
 import (
+	"context"
 	"crypto/tls"
+	"encoding/binary"
+	"fmt"
+	"io"
+	"net"
+	"strconv"
 	"testing"
+	"time"
 
+	"github.com/AdguardTeam/dnsproxy/proxyutil"
+	"github.com/AdguardTeam/golibs/log"
 	"github.com/lucas-clemente/quic-go"
+	"github.com/miekg/dns"
 	"github.com/stretchr/testify/require"
 )
 
 func TestUpstreamDoQ(t *testing.T) {
-	// Create a DNS-over-QUIC upstream
-	address := "quic://dns.adguard.com"
+	srv := startDoQServer(t, 0)
+	t.Cleanup(srv.Shutdown)
+
+	address := fmt.Sprintf("quic://%s", srv.addr)
 
 	var lastState tls.ConnectionState
 	u, err := AddressToUpstream(
 		address,
 		&Options{
+			InsecureSkipVerify: true,
 			VerifyConnection: func(state tls.ConnectionState) error {
 				lastState = state
 				return nil
@@ -47,3 +60,150 @@ func TestUpstreamDoQ(t *testing.T) {
 	// Make sure that the session has been resumed.
 	require.True(t, lastState.DidResume)
 }
+
+func TestUpstreamDoQ_serverRestart(t *testing.T) {
+	// Run the first server instance.
+	srv := startDoQServer(t, 0)
+
+	// Create a DNS-over-QUIC upstream.
+	address := fmt.Sprintf("quic://%s", srv.addr)
+	u, err := AddressToUpstream(address, &Options{InsecureSkipVerify: true, Timeout: time.Second})
+	require.NoError(t, err)
+
+	// Test that the upstream works properly.
+	checkUpstream(t, u, address)
+
+	// Now let's restart the server on the same address.
+	_, portStr, err := net.SplitHostPort(srv.addr)
+	require.NoError(t, err)
+	port, err := strconv.Atoi(portStr)
+	require.NoError(t, err)
+
+	// Shut down the first server.
+	srv.Shutdown()
+
+	// Start the new one on the same port.
+	srv = startDoQServer(t, port)
+
+	// Check that everything works after restart.
+	checkUpstream(t, u, address)
+
+	// Stop the server again.
+	srv.Shutdown()
+
+	// Now try to send a message and make sure that it returns an error.
+	_, err = u.Exchange(createTestMessage())
+	require.Error(t, err)
+
+	// Start the server one more time.
+	srv = startDoQServer(t, port)
+
+	// Check that everything works after the second restart.
+	checkUpstream(t, u, address)
+}
+
+// testDoQServer is an instance of a test DNS-over-QUIC server.
+type testDoQServer struct {
+	// addr is the address that this server listens to.
+	addr string
+
+	// tlsConfig is the TLS configuration that is used for this server.
+	tlsConfig *tls.Config
+
+	// listener is the QUIC connections listener.
+	listener quic.EarlyListener
+}
+
+// Shutdown stops the test server.
+func (s *testDoQServer) Shutdown() {
+	_ = s.listener.Close()
+}
+
+// Serve serves DoQ requests.
+func (s *testDoQServer) Serve() {
+	for {
+		conn, err := s.listener.Accept(context.Background())
+		if err == quic.ErrServerClosed {
+			// Finish serving on ErrServerClosed error.
+			return
+		}
+
+		if err != nil {
+			log.Debug("error while accepting a new connection: %v", err)
+
+			// Don't pass a nil connection to the handler.
+			continue
+		}
+
+		go s.handleQUICConnection(conn)
+	}
+}
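For reference, DoQ frames every DNS message on a stream with a 2-byte big-endian length prefix. handleQUICStream below parses exactly this framing, and proxyutil.AddPrefix is assumed to produce the equivalent on the write side. A minimal sketch of the encoder (imports of encoding/binary and github.com/miekg/dns assumed):

```go
// packDoQMsg packs a DNS message and prepends its length as a 2-byte
// big-endian integer, which is how DoQ frames messages on a QUIC stream.
func packDoQMsg(m *dns.Msg) (framed []byte, err error) {
	data, err := m.Pack()
	if err != nil {
		return nil, err
	}

	framed = make([]byte, 2+len(data))
	binary.BigEndian.PutUint16(framed[:2], uint16(len(data)))
	copy(framed[2:], data)

	return framed, nil
}
```

+
+// handleQUICConnection handles an incoming QUIC connection.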
+func (s *testDoQServer) handleQUICConnection(conn quic.EarlyConnection) { + for { + stream, err := conn.AcceptStream(context.Background()) + if err != nil { + _ = conn.CloseWithError(QUICCodeNoError, "") + + return + } + + go func() { + qErr := s.handleQUICStream(stream) + if qErr != nil { + _ = conn.CloseWithError(QUICCodeNoError, "") + } + }() + } +} + +// handleQUICStream handles new QUIC streams, reads DNS messages and responds to +// them. +func (s *testDoQServer) handleQUICStream(stream quic.Stream) (err error) { + defer stream.Close() + + buf := make([]byte, dns.MaxMsgSize+2) + _, err = stream.Read(buf) + if err != nil && err != io.EOF { + return err + } + + req := &dns.Msg{} + packetLen := binary.BigEndian.Uint16(buf[:2]) + err = req.Unpack(buf[2 : packetLen+2]) + if err != nil { + return err + } + + resp := respondToTestMessage(req) + + buf, err = resp.Pack() + if err != nil { + return err + } + + buf = proxyutil.AddPrefix(buf) + _, err = stream.Write(buf) + + return err +} + +// startDoQServer starts a test DoQ server. +func startDoQServer(t *testing.T, port int) (s *testDoQServer) { + tlsConfig := createServerTLSConfig(t, "127.0.0.1") + tlsConfig.NextProtos = []string{NextProtoDQ} + + listen, err := quic.ListenAddrEarly( + fmt.Sprintf("127.0.0.1:%d", port), + tlsConfig, + &quic.Config{}, + ) + require.NoError(t, err) + + s = &testDoQServer{ + addr: listen.Addr().String(), + tlsConfig: tlsConfig, + listener: listen, + } + + go s.Serve() + + return s +} diff --git a/vendor/github.com/bluele/gcache/LICENSE b/vendor/github.com/bluele/gcache/LICENSE new file mode 100644 index 000000000..d1e7b03e3 --- /dev/null +++ b/vendor/github.com/bluele/gcache/LICENSE @@ -0,0 +1,21 @@ +The MIT License (MIT) + +Copyright (c) 2017 Jun Kimura + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. diff --git a/vendor/github.com/bluele/gcache/README.md b/vendor/github.com/bluele/gcache/README.md new file mode 100644 index 000000000..b8f124b22 --- /dev/null +++ b/vendor/github.com/bluele/gcache/README.md @@ -0,0 +1,320 @@ +# GCache + +![Test](https://github.com/bluele/gcache/workflows/Test/badge.svg) +[![GoDoc](https://godoc.org/github.com/bluele/gcache?status.svg)](https://pkg.go.dev/github.com/bluele/gcache?tab=doc) + +Cache library for golang. It supports expirable Cache, LFU, LRU and ARC. + +## Features + +* Supports expirable Cache, LFU, LRU and ARC. + +* Goroutine safe. + +* Supports event handlers which evict, purge, and add entry. (Optional) + +* Automatically load cache if it doesn't exists. 
(Optional) + +## Install + +``` +$ go get github.com/bluele/gcache +``` + +## Example + +### Manually set a key-value pair. + +```go +package main + +import ( + "github.com/bluele/gcache" + "fmt" +) + +func main() { + gc := gcache.New(20). + LRU(). + Build() + gc.Set("key", "ok") + value, err := gc.Get("key") + if err != nil { + panic(err) + } + fmt.Println("Get:", value) +} +``` + +``` +Get: ok +``` + +### Manually set a key-value pair, with an expiration time. + +```go +package main + +import ( + "github.com/bluele/gcache" + "fmt" + "time" +) + +func main() { + gc := gcache.New(20). + LRU(). + Build() + gc.SetWithExpire("key", "ok", time.Second*10) + value, _ := gc.Get("key") + fmt.Println("Get:", value) + + // Wait for value to expire + time.Sleep(time.Second*10) + + value, err = gc.Get("key") + if err != nil { + panic(err) + } + fmt.Println("Get:", value) +} +``` + +``` +Get: ok +// 10 seconds later, new attempt: +panic: ErrKeyNotFound +``` + + +### Automatically load value + +```go +package main + +import ( + "github.com/bluele/gcache" + "fmt" +) + +func main() { + gc := gcache.New(20). + LRU(). + LoaderFunc(func(key interface{}) (interface{}, error) { + return "ok", nil + }). + Build() + value, err := gc.Get("key") + if err != nil { + panic(err) + } + fmt.Println("Get:", value) +} +``` + +``` +Get: ok +``` + +### Automatically load value with expiration + +```go +package main + +import ( + "fmt" + "time" + + "github.com/bluele/gcache" +) + +func main() { + var evictCounter, loaderCounter, purgeCounter int + gc := gcache.New(20). + LRU(). + LoaderExpireFunc(func(key interface{}) (interface{}, *time.Duration, error) { + loaderCounter++ + expire := 1 * time.Second + return "ok", &expire, nil + }). + EvictedFunc(func(key, value interface{}) { + evictCounter++ + fmt.Println("evicted key:", key) + }). + PurgeVisitorFunc(func(key, value interface{}) { + purgeCounter++ + fmt.Println("purged key:", key) + }). + Build() + value, err := gc.Get("key") + if err != nil { + panic(err) + } + fmt.Println("Get:", value) + time.Sleep(1 * time.Second) + value, err = gc.Get("key") + if err != nil { + panic(err) + } + fmt.Println("Get:", value) + gc.Purge() + if loaderCounter != evictCounter+purgeCounter { + panic("bad") + } +} +``` + +``` +Get: ok +evicted key: key +Get: ok +purged key: key +``` + + +## Cache Algorithm + + * Least-Frequently Used (LFU) + + Discards the least frequently used items first. + + ```go + func main() { + // size: 10 + gc := gcache.New(10). + LFU(). + Build() + gc.Set("key", "value") + } + ``` + + * Least Recently Used (LRU) + + Discards the least recently used items first. + + ```go + func main() { + // size: 10 + gc := gcache.New(10). + LRU(). + Build() + gc.Set("key", "value") + } + ``` + + * Adaptive Replacement Cache (ARC) + + Constantly balances between LRU and LFU, to improve the combined result. + + detail: http://en.wikipedia.org/wiki/Adaptive_replacement_cache + + ```go + func main() { + // size: 10 + gc := gcache.New(10). + ARC(). + Build() + gc.Set("key", "value") + } + ``` + + * SimpleCache (Default) + + SimpleCache has no clear priority for evict cache. It depends on key-value map order. + + ```go + func main() { + // size: 10 + gc := gcache.New(10).Build() + gc.Set("key", "value") + v, err := gc.Get("key") + if err != nil { + panic(err) + } + } + ``` + +## Loading Cache + +If specified `LoaderFunc`, values are automatically loaded by the cache, and are stored in the cache until either evicted or manually invalidated. + +```go +func main() { + gc := gcache.New(10). 
+ LRU(). + LoaderFunc(func(key interface{}) (interface{}, error) { + return "value", nil + }). + Build() + v, _ := gc.Get("key") + // output: "value" + fmt.Println(v) +} +``` + +GCache coordinates cache fills such that only one load in one process of an entire replicated set of processes populates the cache, then multiplexes the loaded value to all callers. + +## Expirable cache + +```go +func main() { + // LRU cache, size: 10, expiration: after a hour + gc := gcache.New(10). + LRU(). + Expiration(time.Hour). + Build() +} +``` + +## Event handlers + +### Evicted handler + +Event handler for evict the entry. + +```go +func main() { + gc := gcache.New(2). + EvictedFunc(func(key, value interface{}) { + fmt.Println("evicted key:", key) + }). + Build() + for i := 0; i < 3; i++ { + gc.Set(i, i*i) + } +} +``` + +``` +evicted key: 0 +``` + +### Added handler + +Event handler for add the entry. + +```go +func main() { + gc := gcache.New(2). + AddedFunc(func(key, value interface{}) { + fmt.Println("added key:", key) + }). + Build() + for i := 0; i < 3; i++ { + gc.Set(i, i*i) + } +} +``` + +``` +added key: 0 +added key: 1 +added key: 2 +``` + +# Author + +**Jun Kimura** + +* +* diff --git a/vendor/github.com/bluele/gcache/arc.go b/vendor/github.com/bluele/gcache/arc.go new file mode 100644 index 000000000..e2015e911 --- /dev/null +++ b/vendor/github.com/bluele/gcache/arc.go @@ -0,0 +1,456 @@ +package gcache + +import ( + "container/list" + "time" +) + +// Constantly balances between LRU and LFU, to improve the combined result. +type ARC struct { + baseCache + items map[interface{}]*arcItem + + part int + t1 *arcList + t2 *arcList + b1 *arcList + b2 *arcList +} + +func newARC(cb *CacheBuilder) *ARC { + c := &ARC{} + buildCache(&c.baseCache, cb) + + c.init() + c.loadGroup.cache = c + return c +} + +func (c *ARC) init() { + c.items = make(map[interface{}]*arcItem) + c.t1 = newARCList() + c.t2 = newARCList() + c.b1 = newARCList() + c.b2 = newARCList() +} + +func (c *ARC) replace(key interface{}) { + if !c.isCacheFull() { + return + } + var old interface{} + if c.t1.Len() > 0 && ((c.b2.Has(key) && c.t1.Len() == c.part) || (c.t1.Len() > c.part)) { + old = c.t1.RemoveTail() + c.b1.PushFront(old) + } else if c.t2.Len() > 0 { + old = c.t2.RemoveTail() + c.b2.PushFront(old) + } else { + old = c.t1.RemoveTail() + c.b1.PushFront(old) + } + item, ok := c.items[old] + if ok { + delete(c.items, old) + if c.evictedFunc != nil { + c.evictedFunc(item.key, item.value) + } + } +} + +func (c *ARC) Set(key, value interface{}) error { + c.mu.Lock() + defer c.mu.Unlock() + _, err := c.set(key, value) + return err +} + +// Set a new key-value pair with an expiration time +func (c *ARC) SetWithExpire(key, value interface{}, expiration time.Duration) error { + c.mu.Lock() + defer c.mu.Unlock() + item, err := c.set(key, value) + if err != nil { + return err + } + + t := c.clock.Now().Add(expiration) + item.(*arcItem).expiration = &t + return nil +} + +func (c *ARC) set(key, value interface{}) (interface{}, error) { + var err error + if c.serializeFunc != nil { + value, err = c.serializeFunc(key, value) + if err != nil { + return nil, err + } + } + + item, ok := c.items[key] + if ok { + item.value = value + } else { + item = &arcItem{ + clock: c.clock, + key: key, + value: value, + } + c.items[key] = item + } + + if c.expiration != nil { + t := c.clock.Now().Add(*c.expiration) + item.expiration = &t + } + + defer func() { + if c.addedFunc != nil { + c.addedFunc(key, value) + } + }() + + if c.t1.Has(key) || c.t2.Has(key) { + 
return item, nil + } + + if elt := c.b1.Lookup(key); elt != nil { + c.setPart(minInt(c.size, c.part+maxInt(c.b2.Len()/c.b1.Len(), 1))) + c.replace(key) + c.b1.Remove(key, elt) + c.t2.PushFront(key) + return item, nil + } + + if elt := c.b2.Lookup(key); elt != nil { + c.setPart(maxInt(0, c.part-maxInt(c.b1.Len()/c.b2.Len(), 1))) + c.replace(key) + c.b2.Remove(key, elt) + c.t2.PushFront(key) + return item, nil + } + + if c.isCacheFull() && c.t1.Len()+c.b1.Len() == c.size { + if c.t1.Len() < c.size { + c.b1.RemoveTail() + c.replace(key) + } else { + pop := c.t1.RemoveTail() + item, ok := c.items[pop] + if ok { + delete(c.items, pop) + if c.evictedFunc != nil { + c.evictedFunc(item.key, item.value) + } + } + } + } else { + total := c.t1.Len() + c.b1.Len() + c.t2.Len() + c.b2.Len() + if total >= c.size { + if total == (2 * c.size) { + if c.b2.Len() > 0 { + c.b2.RemoveTail() + } else { + c.b1.RemoveTail() + } + } + c.replace(key) + } + } + c.t1.PushFront(key) + return item, nil +} + +// Get a value from cache pool using key if it exists. If not exists and it has LoaderFunc, it will generate the value using you have specified LoaderFunc method returns value. +func (c *ARC) Get(key interface{}) (interface{}, error) { + v, err := c.get(key, false) + if err == KeyNotFoundError { + return c.getWithLoader(key, true) + } + return v, err +} + +// GetIFPresent gets a value from cache pool using key if it exists. +// If it dose not exists key, returns KeyNotFoundError. +// And send a request which refresh value for specified key if cache object has LoaderFunc. +func (c *ARC) GetIFPresent(key interface{}) (interface{}, error) { + v, err := c.get(key, false) + if err == KeyNotFoundError { + return c.getWithLoader(key, false) + } + return v, err +} + +func (c *ARC) get(key interface{}, onLoad bool) (interface{}, error) { + v, err := c.getValue(key, onLoad) + if err != nil { + return nil, err + } + if c.deserializeFunc != nil { + return c.deserializeFunc(key, v) + } + return v, nil +} + +func (c *ARC) getValue(key interface{}, onLoad bool) (interface{}, error) { + c.mu.Lock() + defer c.mu.Unlock() + if elt := c.t1.Lookup(key); elt != nil { + c.t1.Remove(key, elt) + item := c.items[key] + if !item.IsExpired(nil) { + c.t2.PushFront(key) + if !onLoad { + c.stats.IncrHitCount() + } + return item.value, nil + } else { + delete(c.items, key) + c.b1.PushFront(key) + if c.evictedFunc != nil { + c.evictedFunc(item.key, item.value) + } + } + } + if elt := c.t2.Lookup(key); elt != nil { + item := c.items[key] + if !item.IsExpired(nil) { + c.t2.MoveToFront(elt) + if !onLoad { + c.stats.IncrHitCount() + } + return item.value, nil + } else { + delete(c.items, key) + c.t2.Remove(key, elt) + c.b2.PushFront(key) + if c.evictedFunc != nil { + c.evictedFunc(item.key, item.value) + } + } + } + + if !onLoad { + c.stats.IncrMissCount() + } + return nil, KeyNotFoundError +} + +func (c *ARC) getWithLoader(key interface{}, isWait bool) (interface{}, error) { + if c.loaderExpireFunc == nil { + return nil, KeyNotFoundError + } + value, _, err := c.load(key, func(v interface{}, expiration *time.Duration, e error) (interface{}, error) { + if e != nil { + return nil, e + } + c.mu.Lock() + defer c.mu.Unlock() + item, err := c.set(key, v) + if err != nil { + return nil, err + } + if expiration != nil { + t := c.clock.Now().Add(*expiration) + item.(*arcItem).expiration = &t + } + return v, nil + }, isWait) + if err != nil { + return nil, err + } + return value, nil +} + +// Has checks if key exists in cache +func (c *ARC) Has(key 
interface{}) bool { + c.mu.RLock() + defer c.mu.RUnlock() + now := time.Now() + return c.has(key, &now) +} + +func (c *ARC) has(key interface{}, now *time.Time) bool { + item, ok := c.items[key] + if !ok { + return false + } + return !item.IsExpired(now) +} + +// Remove removes the provided key from the cache. +func (c *ARC) Remove(key interface{}) bool { + c.mu.Lock() + defer c.mu.Unlock() + + return c.remove(key) +} + +func (c *ARC) remove(key interface{}) bool { + if elt := c.t1.Lookup(key); elt != nil { + c.t1.Remove(key, elt) + item := c.items[key] + delete(c.items, key) + c.b1.PushFront(key) + if c.evictedFunc != nil { + c.evictedFunc(key, item.value) + } + return true + } + + if elt := c.t2.Lookup(key); elt != nil { + c.t2.Remove(key, elt) + item := c.items[key] + delete(c.items, key) + c.b2.PushFront(key) + if c.evictedFunc != nil { + c.evictedFunc(key, item.value) + } + return true + } + + return false +} + +// GetALL returns all key-value pairs in the cache. +func (c *ARC) GetALL(checkExpired bool) map[interface{}]interface{} { + c.mu.RLock() + defer c.mu.RUnlock() + items := make(map[interface{}]interface{}, len(c.items)) + now := time.Now() + for k, item := range c.items { + if !checkExpired || c.has(k, &now) { + items[k] = item.value + } + } + return items +} + +// Keys returns a slice of the keys in the cache. +func (c *ARC) Keys(checkExpired bool) []interface{} { + c.mu.RLock() + defer c.mu.RUnlock() + keys := make([]interface{}, 0, len(c.items)) + now := time.Now() + for k := range c.items { + if !checkExpired || c.has(k, &now) { + keys = append(keys, k) + } + } + return keys +} + +// Len returns the number of items in the cache. +func (c *ARC) Len(checkExpired bool) int { + c.mu.RLock() + defer c.mu.RUnlock() + if !checkExpired { + return len(c.items) + } + var length int + now := time.Now() + for k := range c.items { + if c.has(k, &now) { + length++ + } + } + return length +} + +// Purge is used to completely clear the cache +func (c *ARC) Purge() { + c.mu.Lock() + defer c.mu.Unlock() + + if c.purgeVisitorFunc != nil { + for _, item := range c.items { + c.purgeVisitorFunc(item.key, item.value) + } + } + + c.init() +} + +func (c *ARC) setPart(p int) { + if c.isCacheFull() { + c.part = p + } +} + +func (c *ARC) isCacheFull() bool { + return (c.t1.Len() + c.t2.Len()) == c.size +} + +// IsExpired returns boolean value whether this item is expired or not. 
+func (it *arcItem) IsExpired(now *time.Time) bool { + if it.expiration == nil { + return false + } + if now == nil { + t := it.clock.Now() + now = &t + } + return it.expiration.Before(*now) +} + +type arcList struct { + l *list.List + keys map[interface{}]*list.Element +} + +type arcItem struct { + clock Clock + key interface{} + value interface{} + expiration *time.Time +} + +func newARCList() *arcList { + return &arcList{ + l: list.New(), + keys: make(map[interface{}]*list.Element), + } +} + +func (al *arcList) Has(key interface{}) bool { + _, ok := al.keys[key] + return ok +} + +func (al *arcList) Lookup(key interface{}) *list.Element { + elt := al.keys[key] + return elt +} + +func (al *arcList) MoveToFront(elt *list.Element) { + al.l.MoveToFront(elt) +} + +func (al *arcList) PushFront(key interface{}) { + if elt, ok := al.keys[key]; ok { + al.l.MoveToFront(elt) + return + } + elt := al.l.PushFront(key) + al.keys[key] = elt +} + +func (al *arcList) Remove(key interface{}, elt *list.Element) { + delete(al.keys, key) + al.l.Remove(elt) +} + +func (al *arcList) RemoveTail() interface{} { + elt := al.l.Back() + al.l.Remove(elt) + + key := elt.Value + delete(al.keys, key) + + return key +} + +func (al *arcList) Len() int { + return al.l.Len() +} diff --git a/vendor/github.com/bluele/gcache/cache.go b/vendor/github.com/bluele/gcache/cache.go new file mode 100644 index 000000000..e13e6f1cb --- /dev/null +++ b/vendor/github.com/bluele/gcache/cache.go @@ -0,0 +1,205 @@ +package gcache + +import ( + "errors" + "fmt" + "sync" + "time" +) + +const ( + TYPE_SIMPLE = "simple" + TYPE_LRU = "lru" + TYPE_LFU = "lfu" + TYPE_ARC = "arc" +) + +var KeyNotFoundError = errors.New("Key not found.") + +type Cache interface { + Set(key, value interface{}) error + SetWithExpire(key, value interface{}, expiration time.Duration) error + Get(key interface{}) (interface{}, error) + GetIFPresent(key interface{}) (interface{}, error) + GetALL(checkExpired bool) map[interface{}]interface{} + get(key interface{}, onLoad bool) (interface{}, error) + Remove(key interface{}) bool + Purge() + Keys(checkExpired bool) []interface{} + Len(checkExpired bool) int + Has(key interface{}) bool + + statsAccessor +} + +type baseCache struct { + clock Clock + size int + loaderExpireFunc LoaderExpireFunc + evictedFunc EvictedFunc + purgeVisitorFunc PurgeVisitorFunc + addedFunc AddedFunc + deserializeFunc DeserializeFunc + serializeFunc SerializeFunc + expiration *time.Duration + mu sync.RWMutex + loadGroup Group + *stats +} + +type ( + LoaderFunc func(interface{}) (interface{}, error) + LoaderExpireFunc func(interface{}) (interface{}, *time.Duration, error) + EvictedFunc func(interface{}, interface{}) + PurgeVisitorFunc func(interface{}, interface{}) + AddedFunc func(interface{}, interface{}) + DeserializeFunc func(interface{}, interface{}) (interface{}, error) + SerializeFunc func(interface{}, interface{}) (interface{}, error) +) + +type CacheBuilder struct { + clock Clock + tp string + size int + loaderExpireFunc LoaderExpireFunc + evictedFunc EvictedFunc + purgeVisitorFunc PurgeVisitorFunc + addedFunc AddedFunc + expiration *time.Duration + deserializeFunc DeserializeFunc + serializeFunc SerializeFunc +} + +func New(size int) *CacheBuilder { + return &CacheBuilder{ + clock: NewRealClock(), + tp: TYPE_SIMPLE, + size: size, + } +} + +func (cb *CacheBuilder) Clock(clock Clock) *CacheBuilder { + cb.clock = clock + return cb +} + +// Set a loader function. +// loaderFunc: create a new value with this function if cached value is expired. 
+func (cb *CacheBuilder) LoaderFunc(loaderFunc LoaderFunc) *CacheBuilder { + cb.loaderExpireFunc = func(k interface{}) (interface{}, *time.Duration, error) { + v, err := loaderFunc(k) + return v, nil, err + } + return cb +} + +// Set a loader function with expiration. +// loaderExpireFunc: create a new value with this function if cached value is expired. +// If nil returned instead of time.Duration from loaderExpireFunc than value will never expire. +func (cb *CacheBuilder) LoaderExpireFunc(loaderExpireFunc LoaderExpireFunc) *CacheBuilder { + cb.loaderExpireFunc = loaderExpireFunc + return cb +} + +func (cb *CacheBuilder) EvictType(tp string) *CacheBuilder { + cb.tp = tp + return cb +} + +func (cb *CacheBuilder) Simple() *CacheBuilder { + return cb.EvictType(TYPE_SIMPLE) +} + +func (cb *CacheBuilder) LRU() *CacheBuilder { + return cb.EvictType(TYPE_LRU) +} + +func (cb *CacheBuilder) LFU() *CacheBuilder { + return cb.EvictType(TYPE_LFU) +} + +func (cb *CacheBuilder) ARC() *CacheBuilder { + return cb.EvictType(TYPE_ARC) +} + +func (cb *CacheBuilder) EvictedFunc(evictedFunc EvictedFunc) *CacheBuilder { + cb.evictedFunc = evictedFunc + return cb +} + +func (cb *CacheBuilder) PurgeVisitorFunc(purgeVisitorFunc PurgeVisitorFunc) *CacheBuilder { + cb.purgeVisitorFunc = purgeVisitorFunc + return cb +} + +func (cb *CacheBuilder) AddedFunc(addedFunc AddedFunc) *CacheBuilder { + cb.addedFunc = addedFunc + return cb +} + +func (cb *CacheBuilder) DeserializeFunc(deserializeFunc DeserializeFunc) *CacheBuilder { + cb.deserializeFunc = deserializeFunc + return cb +} + +func (cb *CacheBuilder) SerializeFunc(serializeFunc SerializeFunc) *CacheBuilder { + cb.serializeFunc = serializeFunc + return cb +} + +func (cb *CacheBuilder) Expiration(expiration time.Duration) *CacheBuilder { + cb.expiration = &expiration + return cb +} + +func (cb *CacheBuilder) Build() Cache { + if cb.size <= 0 && cb.tp != TYPE_SIMPLE { + panic("gcache: Cache size <= 0") + } + + return cb.build() +} + +func (cb *CacheBuilder) build() Cache { + switch cb.tp { + case TYPE_SIMPLE: + return newSimpleCache(cb) + case TYPE_LRU: + return newLRUCache(cb) + case TYPE_LFU: + return newLFUCache(cb) + case TYPE_ARC: + return newARC(cb) + default: + panic("gcache: Unknown type " + cb.tp) + } +} + +func buildCache(c *baseCache, cb *CacheBuilder) { + c.clock = cb.clock + c.size = cb.size + c.loaderExpireFunc = cb.loaderExpireFunc + c.expiration = cb.expiration + c.addedFunc = cb.addedFunc + c.deserializeFunc = cb.deserializeFunc + c.serializeFunc = cb.serializeFunc + c.evictedFunc = cb.evictedFunc + c.purgeVisitorFunc = cb.purgeVisitorFunc + c.stats = &stats{} +} + +// load a new value using by specified key. 
+func (c *baseCache) load(key interface{}, cb func(interface{}, *time.Duration, error) (interface{}, error), isWait bool) (interface{}, bool, error) { + v, called, err := c.loadGroup.Do(key, func() (v interface{}, e error) { + defer func() { + if r := recover(); r != nil { + e = fmt.Errorf("Loader panics: %v", r) + } + }() + return cb(c.loaderExpireFunc(key)) + }, isWait) + if err != nil { + return nil, called, err + } + return v, called, nil +} diff --git a/vendor/github.com/bluele/gcache/clock.go b/vendor/github.com/bluele/gcache/clock.go new file mode 100644 index 000000000..3acc3f0db --- /dev/null +++ b/vendor/github.com/bluele/gcache/clock.go @@ -0,0 +1,53 @@ +package gcache + +import ( + "sync" + "time" +) + +type Clock interface { + Now() time.Time +} + +type RealClock struct{} + +func NewRealClock() Clock { + return RealClock{} +} + +func (rc RealClock) Now() time.Time { + t := time.Now() + return t +} + +type FakeClock interface { + Clock + + Advance(d time.Duration) +} + +func NewFakeClock() FakeClock { + return &fakeclock{ + // Taken from github.com/jonboulle/clockwork: use a fixture that does not fulfill Time.IsZero() + now: time.Date(1984, time.April, 4, 0, 0, 0, 0, time.UTC), + } +} + +type fakeclock struct { + now time.Time + + mutex sync.RWMutex +} + +func (fc *fakeclock) Now() time.Time { + fc.mutex.RLock() + defer fc.mutex.RUnlock() + t := fc.now + return t +} + +func (fc *fakeclock) Advance(d time.Duration) { + fc.mutex.Lock() + defer fc.mutex.Unlock() + fc.now = fc.now.Add(d) +} diff --git a/vendor/github.com/bluele/gcache/lfu.go b/vendor/github.com/bluele/gcache/lfu.go new file mode 100644 index 000000000..9a4e3dfeb --- /dev/null +++ b/vendor/github.com/bluele/gcache/lfu.go @@ -0,0 +1,377 @@ +package gcache + +import ( + "container/list" + "time" +) + +// Discards the least frequently used items first. 
+type LFUCache struct { + baseCache + items map[interface{}]*lfuItem + freqList *list.List // list for freqEntry +} + +var _ Cache = (*LFUCache)(nil) + +type lfuItem struct { + clock Clock + key interface{} + value interface{} + freqElement *list.Element + expiration *time.Time +} + +type freqEntry struct { + freq uint + items map[*lfuItem]struct{} +} + +func newLFUCache(cb *CacheBuilder) *LFUCache { + c := &LFUCache{} + buildCache(&c.baseCache, cb) + + c.init() + c.loadGroup.cache = c + return c +} + +func (c *LFUCache) init() { + c.freqList = list.New() + c.items = make(map[interface{}]*lfuItem, c.size) + c.freqList.PushFront(&freqEntry{ + freq: 0, + items: make(map[*lfuItem]struct{}), + }) +} + +// Set a new key-value pair +func (c *LFUCache) Set(key, value interface{}) error { + c.mu.Lock() + defer c.mu.Unlock() + _, err := c.set(key, value) + return err +} + +// Set a new key-value pair with an expiration time +func (c *LFUCache) SetWithExpire(key, value interface{}, expiration time.Duration) error { + c.mu.Lock() + defer c.mu.Unlock() + item, err := c.set(key, value) + if err != nil { + return err + } + + t := c.clock.Now().Add(expiration) + item.(*lfuItem).expiration = &t + return nil +} + +func (c *LFUCache) set(key, value interface{}) (interface{}, error) { + var err error + if c.serializeFunc != nil { + value, err = c.serializeFunc(key, value) + if err != nil { + return nil, err + } + } + + // Check for existing item + item, ok := c.items[key] + if ok { + item.value = value + } else { + // Verify size not exceeded + if len(c.items) >= c.size { + c.evict(1) + } + item = &lfuItem{ + clock: c.clock, + key: key, + value: value, + freqElement: nil, + } + el := c.freqList.Front() + fe := el.Value.(*freqEntry) + fe.items[item] = struct{}{} + + item.freqElement = el + c.items[key] = item + } + + if c.expiration != nil { + t := c.clock.Now().Add(*c.expiration) + item.expiration = &t + } + + if c.addedFunc != nil { + c.addedFunc(key, value) + } + + return item, nil +} + +// Get a value from cache pool using key if it exists. +// If it dose not exists key and has LoaderFunc, +// generate a value using `LoaderFunc` method returns value. +func (c *LFUCache) Get(key interface{}) (interface{}, error) { + v, err := c.get(key, false) + if err == KeyNotFoundError { + return c.getWithLoader(key, true) + } + return v, err +} + +// GetIFPresent gets a value from cache pool using key if it exists. +// If it dose not exists key, returns KeyNotFoundError. +// And send a request which refresh value for specified key if cache object has LoaderFunc. 
+func (c *LFUCache) GetIFPresent(key interface{}) (interface{}, error) { + v, err := c.get(key, false) + if err == KeyNotFoundError { + return c.getWithLoader(key, false) + } + return v, err +} + +func (c *LFUCache) get(key interface{}, onLoad bool) (interface{}, error) { + v, err := c.getValue(key, onLoad) + if err != nil { + return nil, err + } + if c.deserializeFunc != nil { + return c.deserializeFunc(key, v) + } + return v, nil +} + +func (c *LFUCache) getValue(key interface{}, onLoad bool) (interface{}, error) { + c.mu.Lock() + item, ok := c.items[key] + if ok { + if !item.IsExpired(nil) { + c.increment(item) + v := item.value + c.mu.Unlock() + if !onLoad { + c.stats.IncrHitCount() + } + return v, nil + } + c.removeItem(item) + } + c.mu.Unlock() + if !onLoad { + c.stats.IncrMissCount() + } + return nil, KeyNotFoundError +} + +func (c *LFUCache) getWithLoader(key interface{}, isWait bool) (interface{}, error) { + if c.loaderExpireFunc == nil { + return nil, KeyNotFoundError + } + value, _, err := c.load(key, func(v interface{}, expiration *time.Duration, e error) (interface{}, error) { + if e != nil { + return nil, e + } + c.mu.Lock() + defer c.mu.Unlock() + item, err := c.set(key, v) + if err != nil { + return nil, err + } + if expiration != nil { + t := c.clock.Now().Add(*expiration) + item.(*lfuItem).expiration = &t + } + return v, nil + }, isWait) + if err != nil { + return nil, err + } + return value, nil +} + +func (c *LFUCache) increment(item *lfuItem) { + currentFreqElement := item.freqElement + currentFreqEntry := currentFreqElement.Value.(*freqEntry) + nextFreq := currentFreqEntry.freq + 1 + delete(currentFreqEntry.items, item) + + // a boolean whether reuse the empty current entry + removable := isRemovableFreqEntry(currentFreqEntry) + + // insert item into a valid entry + nextFreqElement := currentFreqElement.Next() + switch { + case nextFreqElement == nil || nextFreqElement.Value.(*freqEntry).freq > nextFreq: + if removable { + currentFreqEntry.freq = nextFreq + nextFreqElement = currentFreqElement + } else { + nextFreqElement = c.freqList.InsertAfter(&freqEntry{ + freq: nextFreq, + items: make(map[*lfuItem]struct{}), + }, currentFreqElement) + } + case nextFreqElement.Value.(*freqEntry).freq == nextFreq: + if removable { + c.freqList.Remove(currentFreqElement) + } + default: + panic("unreachable") + } + nextFreqElement.Value.(*freqEntry).items[item] = struct{}{} + item.freqElement = nextFreqElement +} + +// evict removes the least frequence item from the cache. +func (c *LFUCache) evict(count int) { + entry := c.freqList.Front() + for i := 0; i < count; { + if entry == nil { + return + } else { + for item := range entry.Value.(*freqEntry).items { + if i >= count { + return + } + c.removeItem(item) + i++ + } + entry = entry.Next() + } + } +} + +// Has checks if key exists in cache +func (c *LFUCache) Has(key interface{}) bool { + c.mu.RLock() + defer c.mu.RUnlock() + now := time.Now() + return c.has(key, &now) +} + +func (c *LFUCache) has(key interface{}, now *time.Time) bool { + item, ok := c.items[key] + if !ok { + return false + } + return !item.IsExpired(now) +} + +// Remove removes the provided key from the cache. 
+func (c *LFUCache) Remove(key interface{}) bool { + c.mu.Lock() + defer c.mu.Unlock() + + return c.remove(key) +} + +func (c *LFUCache) remove(key interface{}) bool { + if item, ok := c.items[key]; ok { + c.removeItem(item) + return true + } + return false +} + +// removeElement is used to remove a given list element from the cache +func (c *LFUCache) removeItem(item *lfuItem) { + entry := item.freqElement.Value.(*freqEntry) + delete(c.items, item.key) + delete(entry.items, item) + if isRemovableFreqEntry(entry) { + c.freqList.Remove(item.freqElement) + } + if c.evictedFunc != nil { + c.evictedFunc(item.key, item.value) + } +} + +func (c *LFUCache) keys() []interface{} { + c.mu.RLock() + defer c.mu.RUnlock() + keys := make([]interface{}, len(c.items)) + var i = 0 + for k := range c.items { + keys[i] = k + i++ + } + return keys +} + +// GetALL returns all key-value pairs in the cache. +func (c *LFUCache) GetALL(checkExpired bool) map[interface{}]interface{} { + c.mu.RLock() + defer c.mu.RUnlock() + items := make(map[interface{}]interface{}, len(c.items)) + now := time.Now() + for k, item := range c.items { + if !checkExpired || c.has(k, &now) { + items[k] = item.value + } + } + return items +} + +// Keys returns a slice of the keys in the cache. +func (c *LFUCache) Keys(checkExpired bool) []interface{} { + c.mu.RLock() + defer c.mu.RUnlock() + keys := make([]interface{}, 0, len(c.items)) + now := time.Now() + for k := range c.items { + if !checkExpired || c.has(k, &now) { + keys = append(keys, k) + } + } + return keys +} + +// Len returns the number of items in the cache. +func (c *LFUCache) Len(checkExpired bool) int { + c.mu.RLock() + defer c.mu.RUnlock() + if !checkExpired { + return len(c.items) + } + var length int + now := time.Now() + for k := range c.items { + if c.has(k, &now) { + length++ + } + } + return length +} + +// Completely clear the cache +func (c *LFUCache) Purge() { + c.mu.Lock() + defer c.mu.Unlock() + + if c.purgeVisitorFunc != nil { + for key, item := range c.items { + c.purgeVisitorFunc(key, item.value) + } + } + + c.init() +} + +// IsExpired returns boolean value whether this item is expired or not. +func (it *lfuItem) IsExpired(now *time.Time) bool { + if it.expiration == nil { + return false + } + if now == nil { + t := it.clock.Now() + now = &t + } + return it.expiration.Before(*now) +} + +func isRemovableFreqEntry(entry *freqEntry) bool { + return entry.freq != 0 && len(entry.items) == 0 +} diff --git a/vendor/github.com/bluele/gcache/lru.go b/vendor/github.com/bluele/gcache/lru.go new file mode 100644 index 000000000..a85d66039 --- /dev/null +++ b/vendor/github.com/bluele/gcache/lru.go @@ -0,0 +1,317 @@ +package gcache + +import ( + "container/list" + "time" +) + +// Discards the least recently used items first. 
+type LRUCache struct { + baseCache + items map[interface{}]*list.Element + evictList *list.List +} + +func newLRUCache(cb *CacheBuilder) *LRUCache { + c := &LRUCache{} + buildCache(&c.baseCache, cb) + + c.init() + c.loadGroup.cache = c + return c +} + +func (c *LRUCache) init() { + c.evictList = list.New() + c.items = make(map[interface{}]*list.Element, c.size+1) +} + +func (c *LRUCache) set(key, value interface{}) (interface{}, error) { + var err error + if c.serializeFunc != nil { + value, err = c.serializeFunc(key, value) + if err != nil { + return nil, err + } + } + + // Check for existing item + var item *lruItem + if it, ok := c.items[key]; ok { + c.evictList.MoveToFront(it) + item = it.Value.(*lruItem) + item.value = value + } else { + // Verify size not exceeded + if c.evictList.Len() >= c.size { + c.evict(1) + } + item = &lruItem{ + clock: c.clock, + key: key, + value: value, + } + c.items[key] = c.evictList.PushFront(item) + } + + if c.expiration != nil { + t := c.clock.Now().Add(*c.expiration) + item.expiration = &t + } + + if c.addedFunc != nil { + c.addedFunc(key, value) + } + + return item, nil +} + +// set a new key-value pair +func (c *LRUCache) Set(key, value interface{}) error { + c.mu.Lock() + defer c.mu.Unlock() + _, err := c.set(key, value) + return err +} + +// Set a new key-value pair with an expiration time +func (c *LRUCache) SetWithExpire(key, value interface{}, expiration time.Duration) error { + c.mu.Lock() + defer c.mu.Unlock() + item, err := c.set(key, value) + if err != nil { + return err + } + + t := c.clock.Now().Add(expiration) + item.(*lruItem).expiration = &t + return nil +} + +// Get a value from cache pool using key if it exists. +// If it dose not exists key and has LoaderFunc, +// generate a value using `LoaderFunc` method returns value. +func (c *LRUCache) Get(key interface{}) (interface{}, error) { + v, err := c.get(key, false) + if err == KeyNotFoundError { + return c.getWithLoader(key, true) + } + return v, err +} + +// GetIFPresent gets a value from cache pool using key if it exists. +// If it dose not exists key, returns KeyNotFoundError. +// And send a request which refresh value for specified key if cache object has LoaderFunc. 
+func (c *LRUCache) GetIFPresent(key interface{}) (interface{}, error) { + v, err := c.get(key, false) + if err == KeyNotFoundError { + return c.getWithLoader(key, false) + } + return v, err +} + +func (c *LRUCache) get(key interface{}, onLoad bool) (interface{}, error) { + v, err := c.getValue(key, onLoad) + if err != nil { + return nil, err + } + if c.deserializeFunc != nil { + return c.deserializeFunc(key, v) + } + return v, nil +} + +func (c *LRUCache) getValue(key interface{}, onLoad bool) (interface{}, error) { + c.mu.Lock() + item, ok := c.items[key] + if ok { + it := item.Value.(*lruItem) + if !it.IsExpired(nil) { + c.evictList.MoveToFront(item) + v := it.value + c.mu.Unlock() + if !onLoad { + c.stats.IncrHitCount() + } + return v, nil + } + c.removeElement(item) + } + c.mu.Unlock() + if !onLoad { + c.stats.IncrMissCount() + } + return nil, KeyNotFoundError +} + +func (c *LRUCache) getWithLoader(key interface{}, isWait bool) (interface{}, error) { + if c.loaderExpireFunc == nil { + return nil, KeyNotFoundError + } + value, _, err := c.load(key, func(v interface{}, expiration *time.Duration, e error) (interface{}, error) { + if e != nil { + return nil, e + } + c.mu.Lock() + defer c.mu.Unlock() + item, err := c.set(key, v) + if err != nil { + return nil, err + } + if expiration != nil { + t := c.clock.Now().Add(*expiration) + item.(*lruItem).expiration = &t + } + return v, nil + }, isWait) + if err != nil { + return nil, err + } + return value, nil +} + +// evict removes the oldest item from the cache. +func (c *LRUCache) evict(count int) { + for i := 0; i < count; i++ { + ent := c.evictList.Back() + if ent == nil { + return + } else { + c.removeElement(ent) + } + } +} + +// Has checks if key exists in cache +func (c *LRUCache) Has(key interface{}) bool { + c.mu.RLock() + defer c.mu.RUnlock() + now := time.Now() + return c.has(key, &now) +} + +func (c *LRUCache) has(key interface{}, now *time.Time) bool { + item, ok := c.items[key] + if !ok { + return false + } + return !item.Value.(*lruItem).IsExpired(now) +} + +// Remove removes the provided key from the cache. +func (c *LRUCache) Remove(key interface{}) bool { + c.mu.Lock() + defer c.mu.Unlock() + + return c.remove(key) +} + +func (c *LRUCache) remove(key interface{}) bool { + if ent, ok := c.items[key]; ok { + c.removeElement(ent) + return true + } + return false +} + +func (c *LRUCache) removeElement(e *list.Element) { + c.evictList.Remove(e) + entry := e.Value.(*lruItem) + delete(c.items, entry.key) + if c.evictedFunc != nil { + entry := e.Value.(*lruItem) + c.evictedFunc(entry.key, entry.value) + } +} + +func (c *LRUCache) keys() []interface{} { + c.mu.RLock() + defer c.mu.RUnlock() + keys := make([]interface{}, len(c.items)) + var i = 0 + for k := range c.items { + keys[i] = k + i++ + } + return keys +} + +// GetALL returns all key-value pairs in the cache. +func (c *LRUCache) GetALL(checkExpired bool) map[interface{}]interface{} { + c.mu.RLock() + defer c.mu.RUnlock() + items := make(map[interface{}]interface{}, len(c.items)) + now := time.Now() + for k, item := range c.items { + if !checkExpired || c.has(k, &now) { + items[k] = item.Value.(*lruItem).value + } + } + return items +} + +// Keys returns a slice of the keys in the cache. 
+func (c *LRUCache) Keys(checkExpired bool) []interface{} { + c.mu.RLock() + defer c.mu.RUnlock() + keys := make([]interface{}, 0, len(c.items)) + now := time.Now() + for k := range c.items { + if !checkExpired || c.has(k, &now) { + keys = append(keys, k) + } + } + return keys +} + +// Len returns the number of items in the cache. +func (c *LRUCache) Len(checkExpired bool) int { + c.mu.RLock() + defer c.mu.RUnlock() + if !checkExpired { + return len(c.items) + } + var length int + now := time.Now() + for k := range c.items { + if c.has(k, &now) { + length++ + } + } + return length +} + +// Completely clear the cache +func (c *LRUCache) Purge() { + c.mu.Lock() + defer c.mu.Unlock() + + if c.purgeVisitorFunc != nil { + for key, item := range c.items { + it := item.Value.(*lruItem) + v := it.value + c.purgeVisitorFunc(key, v) + } + } + + c.init() +} + +type lruItem struct { + clock Clock + key interface{} + value interface{} + expiration *time.Time +} + +// IsExpired returns boolean value whether this item is expired or not. +func (it *lruItem) IsExpired(now *time.Time) bool { + if it.expiration == nil { + return false + } + if now == nil { + t := it.clock.Now() + now = &t + } + return it.expiration.Before(*now) +} diff --git a/vendor/github.com/bluele/gcache/simple.go b/vendor/github.com/bluele/gcache/simple.go new file mode 100644 index 000000000..7310af141 --- /dev/null +++ b/vendor/github.com/bluele/gcache/simple.go @@ -0,0 +1,307 @@ +package gcache + +import ( + "time" +) + +// SimpleCache has no clear priority for evict cache. It depends on key-value map order. +type SimpleCache struct { + baseCache + items map[interface{}]*simpleItem +} + +func newSimpleCache(cb *CacheBuilder) *SimpleCache { + c := &SimpleCache{} + buildCache(&c.baseCache, cb) + + c.init() + c.loadGroup.cache = c + return c +} + +func (c *SimpleCache) init() { + if c.size <= 0 { + c.items = make(map[interface{}]*simpleItem) + } else { + c.items = make(map[interface{}]*simpleItem, c.size) + } +} + +// Set a new key-value pair +func (c *SimpleCache) Set(key, value interface{}) error { + c.mu.Lock() + defer c.mu.Unlock() + _, err := c.set(key, value) + return err +} + +// Set a new key-value pair with an expiration time +func (c *SimpleCache) SetWithExpire(key, value interface{}, expiration time.Duration) error { + c.mu.Lock() + defer c.mu.Unlock() + item, err := c.set(key, value) + if err != nil { + return err + } + + t := c.clock.Now().Add(expiration) + item.(*simpleItem).expiration = &t + return nil +} + +func (c *SimpleCache) set(key, value interface{}) (interface{}, error) { + var err error + if c.serializeFunc != nil { + value, err = c.serializeFunc(key, value) + if err != nil { + return nil, err + } + } + + // Check for existing item + item, ok := c.items[key] + if ok { + item.value = value + } else { + // Verify size not exceeded + if (len(c.items) >= c.size) && c.size > 0 { + c.evict(1) + } + item = &simpleItem{ + clock: c.clock, + value: value, + } + c.items[key] = item + } + + if c.expiration != nil { + t := c.clock.Now().Add(*c.expiration) + item.expiration = &t + } + + if c.addedFunc != nil { + c.addedFunc(key, value) + } + + return item, nil +} + +// Get a value from cache pool using key if it exists. +// If it dose not exists key and has LoaderFunc, +// generate a value using `LoaderFunc` method returns value. 
+func (c *SimpleCache) Get(key interface{}) (interface{}, error) { + v, err := c.get(key, false) + if err == KeyNotFoundError { + return c.getWithLoader(key, true) + } + return v, err +} + +// GetIFPresent gets a value from cache pool using key if it exists. +// If it dose not exists key, returns KeyNotFoundError. +// And send a request which refresh value for specified key if cache object has LoaderFunc. +func (c *SimpleCache) GetIFPresent(key interface{}) (interface{}, error) { + v, err := c.get(key, false) + if err == KeyNotFoundError { + return c.getWithLoader(key, false) + } + return v, nil +} + +func (c *SimpleCache) get(key interface{}, onLoad bool) (interface{}, error) { + v, err := c.getValue(key, onLoad) + if err != nil { + return nil, err + } + if c.deserializeFunc != nil { + return c.deserializeFunc(key, v) + } + return v, nil +} + +func (c *SimpleCache) getValue(key interface{}, onLoad bool) (interface{}, error) { + c.mu.Lock() + item, ok := c.items[key] + if ok { + if !item.IsExpired(nil) { + v := item.value + c.mu.Unlock() + if !onLoad { + c.stats.IncrHitCount() + } + return v, nil + } + c.remove(key) + } + c.mu.Unlock() + if !onLoad { + c.stats.IncrMissCount() + } + return nil, KeyNotFoundError +} + +func (c *SimpleCache) getWithLoader(key interface{}, isWait bool) (interface{}, error) { + if c.loaderExpireFunc == nil { + return nil, KeyNotFoundError + } + value, _, err := c.load(key, func(v interface{}, expiration *time.Duration, e error) (interface{}, error) { + if e != nil { + return nil, e + } + c.mu.Lock() + defer c.mu.Unlock() + item, err := c.set(key, v) + if err != nil { + return nil, err + } + if expiration != nil { + t := c.clock.Now().Add(*expiration) + item.(*simpleItem).expiration = &t + } + return v, nil + }, isWait) + if err != nil { + return nil, err + } + return value, nil +} + +func (c *SimpleCache) evict(count int) { + now := c.clock.Now() + current := 0 + for key, item := range c.items { + if current >= count { + return + } + if item.expiration == nil || now.After(*item.expiration) { + defer c.remove(key) + current++ + } + } +} + +// Has checks if key exists in cache +func (c *SimpleCache) Has(key interface{}) bool { + c.mu.RLock() + defer c.mu.RUnlock() + now := time.Now() + return c.has(key, &now) +} + +func (c *SimpleCache) has(key interface{}, now *time.Time) bool { + item, ok := c.items[key] + if !ok { + return false + } + return !item.IsExpired(now) +} + +// Remove removes the provided key from the cache. +func (c *SimpleCache) Remove(key interface{}) bool { + c.mu.Lock() + defer c.mu.Unlock() + + return c.remove(key) +} + +func (c *SimpleCache) remove(key interface{}) bool { + item, ok := c.items[key] + if ok { + delete(c.items, key) + if c.evictedFunc != nil { + c.evictedFunc(key, item.value) + } + return true + } + return false +} + +// Returns a slice of the keys in the cache. +func (c *SimpleCache) keys() []interface{} { + c.mu.RLock() + defer c.mu.RUnlock() + keys := make([]interface{}, len(c.items)) + var i = 0 + for k := range c.items { + keys[i] = k + i++ + } + return keys +} + +// GetALL returns all key-value pairs in the cache. +func (c *SimpleCache) GetALL(checkExpired bool) map[interface{}]interface{} { + c.mu.RLock() + defer c.mu.RUnlock() + items := make(map[interface{}]interface{}, len(c.items)) + now := time.Now() + for k, item := range c.items { + if !checkExpired || c.has(k, &now) { + items[k] = item.value + } + } + return items +} + +// Keys returns a slice of the keys in the cache. 
+func (c *SimpleCache) Keys(checkExpired bool) []interface{} { + c.mu.RLock() + defer c.mu.RUnlock() + keys := make([]interface{}, 0, len(c.items)) + now := time.Now() + for k := range c.items { + if !checkExpired || c.has(k, &now) { + keys = append(keys, k) + } + } + return keys +} + +// Len returns the number of items in the cache. +func (c *SimpleCache) Len(checkExpired bool) int { + c.mu.RLock() + defer c.mu.RUnlock() + if !checkExpired { + return len(c.items) + } + var length int + now := time.Now() + for k := range c.items { + if c.has(k, &now) { + length++ + } + } + return length +} + +// Completely clear the cache +func (c *SimpleCache) Purge() { + c.mu.Lock() + defer c.mu.Unlock() + + if c.purgeVisitorFunc != nil { + for key, item := range c.items { + c.purgeVisitorFunc(key, item.value) + } + } + + c.init() +} + +type simpleItem struct { + clock Clock + value interface{} + expiration *time.Time +} + +// IsExpired returns boolean value whether this item is expired or not. +func (si *simpleItem) IsExpired(now *time.Time) bool { + if si.expiration == nil { + return false + } + if now == nil { + t := si.clock.Now() + now = &t + } + return si.expiration.Before(*now) +} diff --git a/vendor/github.com/bluele/gcache/singleflight.go b/vendor/github.com/bluele/gcache/singleflight.go new file mode 100644 index 000000000..2c6285e82 --- /dev/null +++ b/vendor/github.com/bluele/gcache/singleflight.go @@ -0,0 +1,82 @@ +package gcache + +/* +Copyright 2012 Google Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// This module provides a duplicate function call suppression +// mechanism. + +import "sync" + +// call is an in-flight or completed Do call +type call struct { + wg sync.WaitGroup + val interface{} + err error +} + +// Group represents a class of work and forms a namespace in which +// units of work can be executed with duplicate suppression. +type Group struct { + cache Cache + mu sync.Mutex // protects m + m map[interface{}]*call // lazily initialized +} + +// Do executes and returns the results of the given function, making +// sure that only one execution is in-flight for a given key at a +// time. If a duplicate comes in, the duplicate caller waits for the +// original to complete and receives the same results. 
+func (g *Group) Do(key interface{}, fn func() (interface{}, error), isWait bool) (interface{}, bool, error) { + g.mu.Lock() + v, err := g.cache.get(key, true) + if err == nil { + g.mu.Unlock() + return v, false, nil + } + if g.m == nil { + g.m = make(map[interface{}]*call) + } + if c, ok := g.m[key]; ok { + g.mu.Unlock() + if !isWait { + return nil, false, KeyNotFoundError + } + c.wg.Wait() + return c.val, false, c.err + } + c := new(call) + c.wg.Add(1) + g.m[key] = c + g.mu.Unlock() + if !isWait { + go g.call(c, key, fn) + return nil, false, KeyNotFoundError + } + v, err = g.call(c, key, fn) + return v, true, err +} + +func (g *Group) call(c *call, key interface{}, fn func() (interface{}, error)) (interface{}, error) { + c.val, c.err = fn() + c.wg.Done() + + g.mu.Lock() + delete(g.m, key) + g.mu.Unlock() + + return c.val, c.err +} diff --git a/vendor/github.com/bluele/gcache/stats.go b/vendor/github.com/bluele/gcache/stats.go new file mode 100644 index 000000000..ca0bf3185 --- /dev/null +++ b/vendor/github.com/bluele/gcache/stats.go @@ -0,0 +1,53 @@ +package gcache + +import ( + "sync/atomic" +) + +type statsAccessor interface { + HitCount() uint64 + MissCount() uint64 + LookupCount() uint64 + HitRate() float64 +} + +// statistics +type stats struct { + hitCount uint64 + missCount uint64 +} + +// increment hit count +func (st *stats) IncrHitCount() uint64 { + return atomic.AddUint64(&st.hitCount, 1) +} + +// increment miss count +func (st *stats) IncrMissCount() uint64 { + return atomic.AddUint64(&st.missCount, 1) +} + +// HitCount returns hit count +func (st *stats) HitCount() uint64 { + return atomic.LoadUint64(&st.hitCount) +} + +// MissCount returns miss count +func (st *stats) MissCount() uint64 { + return atomic.LoadUint64(&st.missCount) +} + +// LookupCount returns lookup count +func (st *stats) LookupCount() uint64 { + return st.HitCount() + st.MissCount() +} + +// HitRate returns rate for cache hitting +func (st *stats) HitRate() float64 { + hc, mc := st.HitCount(), st.MissCount() + total := hc + mc + if total == 0 { + return 0.0 + } + return float64(hc) / float64(total) +} diff --git a/vendor/github.com/bluele/gcache/utils.go b/vendor/github.com/bluele/gcache/utils.go new file mode 100644 index 000000000..1f784e4c4 --- /dev/null +++ b/vendor/github.com/bluele/gcache/utils.go @@ -0,0 +1,15 @@ +package gcache + +func minInt(x, y int) int { + if x < y { + return x + } + return y +} + +func maxInt(x, y int) int { + if x > y { + return x + } + return y +} diff --git a/vendor/modules.txt b/vendor/modules.txt index 170d03040..967138791 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -22,6 +22,9 @@ github.com/ameshkov/dnsstamps # github.com/beefsack/go-rate v0.0.0-20220214233405-116f4ca011a0 ## explicit github.com/beefsack/go-rate +# github.com/bluele/gcache v0.0.2 +## explicit; go 1.15 +github.com/bluele/gcache # github.com/davecgh/go-spew v1.1.1 ## explicit github.com/davecgh/go-spew/spew
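The modules.txt hunk above registers the newly vendored github.com/bluele/gcache dependency. Where exactly dnsproxy wires it in is not visible in this section; for reference, a small, self-contained sketch of the expirable LRU usage the library is built for, based on the vendored sources above (the key and value are made up for illustration):

```go
package main

import (
	"fmt"
	"time"

	"github.com/bluele/gcache"
)

func main() {
	// An expirable LRU cache: up to 128 entries, each living for a minute.
	cache := gcache.New(128).
		LRU().
		Expiration(time.Minute).
		Build()

	if err := cache.Set("example.org.", []string{"192.0.2.1"}); err != nil {
		panic(err)
	}

	// Get returns gcache.KeyNotFoundError once the entry has expired or has
	// been evicted.
	v, err := cache.Get("example.org.")
	if err != nil {
		panic(err)
	}

	fmt.Println("cached:", v)
}
```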