From 95ef855f83c328636e7e529b130cfe50affd9d16 Mon Sep 17 00:00:00 2001 From: Andrey Meshkov Date: Tue, 18 Oct 2022 15:11:00 +0300 Subject: [PATCH] Pull request: Upstream now implements io.Closer. Merge in DNS/dnsproxy from upstream_closer to master Squashed commit of the following: commit 3ac92bce285c9fa910ebe2ca1b213ca040784a98 Author: Andrey Meshkov Date: Mon Oct 17 23:53:22 2022 +0300 fix formatting commit 3c749a8b8c890f2cd4c7ab609cc137fae037a6b1 Author: Andrey Meshkov Date: Mon Oct 17 23:48:04 2022 +0300 Upstream now implements io.Closer. This is rather important because some of the Upstream implementations actually require explicit cleanup. However, there's a lot of old code that is not aware of the fact that Upstream can be cleaned up. In order to make the life easier for the authors, I used runtime.SetFinalizer where possible to guarantee cleanup. --- fastip/fastest_test.go | 14 +++- proxy/proxy.go | 4 + proxy/proxy_test.go | 115 ++++++++++++++++------------- proxy/upstreams.go | 35 +++++++-- upstream/parallel_test.go | 16 +++- upstream/upstream.go | 2 + upstream/upstream_dnscrypt.go | 6 ++ upstream/upstream_dnscrypt_test.go | 27 ++++--- upstream/upstream_doh.go | 31 +++++++- upstream/upstream_doh_test.go | 4 + upstream/upstream_dot.go | 65 +++++++++------- upstream/upstream_plain.go | 6 ++ upstream/upstream_plain_test.go | 20 ++--- upstream/upstream_pool.go | 78 ++++++++++++------- upstream/upstream_pool_test.go | 58 ++++++--------- upstream/upstream_quic.go | 47 ++++++++---- upstream/upstream_quic_test.go | 4 + upstream/upstream_test.go | 14 +++- 18 files changed, 351 insertions(+), 195 deletions(-) diff --git a/fastip/fastest_test.go b/fastip/fastest_test.go index 2f437a181..3270808eb 100644 --- a/fastip/fastest_test.go +++ b/fastip/fastest_test.go @@ -81,7 +81,7 @@ type errUpstream struct { err error } -func (u errUpstream) Exchange(m *dns.Msg) (*dns.Msg, error) { +func (u errUpstream) Exchange(_ *dns.Msg) (*dns.Msg, error) { return nil, u.err } @@ -89,6 +89,10 @@ type testAUpstream struct { recs []*dns.A } +// type check +var _ upstream.Upstream = (*testAUpstream)(nil) + +// Exchange implements the upstream.Upstream interface for *testAUpstream. func (u *testAUpstream) Exchange(m *dns.Msg) (resp *dns.Msg, err error) { resp = &dns.Msg{} resp.SetReply(m) @@ -100,10 +104,16 @@ func (u *testAUpstream) Exchange(m *dns.Msg) (resp *dns.Msg, err error) { return resp, nil } -func (u *testAUpstream) Address() string { +// Address implements the upstream.Upstream interface for *testAUpstream. +func (u *testAUpstream) Address() (addr string) { return "" } +// Close implements the upstream.Upstream interface for *testAUpstream. +func (u *testAUpstream) Close() (err error) { + return nil +} + func (u *testAUpstream) add(host string, ip net.IP) (chain *testAUpstream) { u.recs = append(u.recs, &dns.A{ Hdr: dns.RR_Header{ diff --git a/proxy/proxy.go b/proxy/proxy.go index efb91dd22..7398304f9 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -261,6 +261,10 @@ func (p *Proxy) Stop() error { closeAll(p.dnsCryptTCPListen, &errs) p.dnsCryptTCPListen = nil + if p.UpstreamConfig != nil { + closeAll([]io.Closer{p.UpstreamConfig}, &errs) + } + p.started = false log.Println("Stopped the DNS proxy server") if len(errs) > 0 { diff --git a/proxy/proxy_test.go b/proxy/proxy_test.go index 6596d1319..3960810b5 100644 --- a/proxy/proxy_test.go +++ b/proxy/proxy_test.go @@ -82,6 +82,10 @@ type testDNSSECUpstream struct { rrsig dns.RR } +// type check +var _ upstream.Upstream = (*testDNSSECUpstream)(nil) + +// Exchange implements the upstream.Upstream interface for *testDNSSECUpstream. func (u *testDNSSECUpstream) Exchange(m *dns.Msg) (resp *dns.Msg, err error) { resp = &dns.Msg{} resp.SetReply(m) @@ -113,10 +117,16 @@ func (u *testDNSSECUpstream) Exchange(m *dns.Msg) (resp *dns.Msg, err error) { return resp, nil } +// Address implements the upstream.Upstream interface for *testDNSSECUpstream. func (u *testDNSSECUpstream) Address() string { return "" } +// Close implements the upstream.Upstream interface for *testDNSSECUpstream. +func (u *testDNSSECUpstream) Close() (err error) { + return nil +} + func TestProxy_Resolve_dnssecCache(t *testing.T) { const host = "example.com" @@ -347,90 +357,71 @@ func TestUpstreamsSort(t *testing.T) { func TestExchangeWithReservedDomains(t *testing.T) { dnsProxy := createTestProxy(t, nil) - // upstreams specification. Domains adguard.com and google.ru reserved with fake upstreams, maps.google.ru excluded from dnsmasq. - upstreams := []string{"[/adguard.com/]1.2.3.4", "[/google.ru/]2.3.4.5", "[/maps.google.ru/]#", "1.1.1.1"} + // Upstreams specification. Domains adguard.com and google.ru reserved + // with fake upstreams, maps.google.ru excluded from dnsmasq. + upstreams := []string{ + "[/adguard.com/]1.2.3.4", + "[/google.ru/]2.3.4.5", + "[/maps.google.ru/]#", + "1.1.1.1", + } config, err := ParseUpstreamsConfig( upstreams, &upstream.Options{ InsecureSkipVerify: false, Bootstrap: []string{"8.8.8.8"}, Timeout: 1 * time.Second, - }) - if err != nil { - t.Fatalf("Error while upstream config parsing: %s", err) - } + }, + ) + require.NoError(t, err) + dnsProxy.UpstreamConfig = config err = dnsProxy.Start() - if err != nil { - t.Fatalf("cannot start the DNS proxy: %s", err) - } + require.NoError(t, err) + testutil.CleanupAndRequireSuccess(t, dnsProxy.Stop) - // create a DNS-over-TCP client connection + // Create a DNS-over-TCP client connection. addr := dnsProxy.Addr(ProtoTCP) conn, err := dns.Dial("tcp", addr.String()) - if err != nil { - t.Fatalf("cannot connect to the proxy: %s", err) - } + require.NoError(t, err) - // create google-a test message + // Create google-a test message. req := createTestMessage() err = conn.WriteMsg(req) - if err != nil { - t.Fatalf("cannot write message: %s", err) - } + require.NoError(t, err) - // make sure if dnsproxy is working + // Make sure that dnsproxy is working. res, err := conn.ReadMsg() - if err != nil { - t.Fatalf("cannot read response to message: %s", err) - } + require.NoError(t, err) requireResponse(t, req, res) - // create adguard.com test message + // Create adguard.com test message. req = createHostTestMessage("adguard.com") err = conn.WriteMsg(req) - if err != nil { - t.Fatalf("cannot write message: %s", err) - } + require.NoError(t, err) - // test message should not be resolved + // Test message should not be resolved. res, _ = conn.ReadMsg() - if res.Answer != nil { - t.Fatal("adguard.com should not be resolved") - } + require.Nil(t, res.Answer) - // create www.google.ru test message + // Create www.google.ru test message. req = createHostTestMessage("www.google.ru") err = conn.WriteMsg(req) - if err != nil { - t.Fatalf("cannot write message: %s", err) - } + require.NoError(t, err) - // test message should not be resolved + // Test message should not be resolved. res, _ = conn.ReadMsg() - if res.Answer != nil { - t.Fatal("www.google.ru should not be resolved") - } + require.Nil(t, res.Answer) - // create maps.google.ru test message + // Create maps.google.ru test message. req = createHostTestMessage("maps.google.ru") err = conn.WriteMsg(req) - if err != nil { - t.Fatalf("cannot write message: %s", err) - } + require.NoError(t, err) - // test message should be resolved + // Test message should be resolved. res, _ = conn.ReadMsg() - if res.Answer == nil { - t.Fatal("maps.google.ru should be resolved") - } - - // Stop the proxy - err = dnsProxy.Stop() - if err != nil { - t.Fatalf("cannot stop the DNS proxy: %s", err) - } + require.NotNil(t, res.Answer) } // TestOneByOneUpstreamsExchange tries to resolve DNS request @@ -757,6 +748,9 @@ type funcUpstream struct { addressFunc func() (addr string) } +// type check +var _ upstream.Upstream = (*funcUpstream)(nil) + // Exchange implements upstream.Upstream interface for *funcUpstream. func (wu *funcUpstream) Exchange(m *dns.Msg) (*dns.Msg, error) { if wu.exchangeFunc == nil { @@ -767,7 +761,7 @@ func (wu *funcUpstream) Exchange(m *dns.Msg) (*dns.Msg, error) { } // Address implements upstream.Upstream interface for *funcUpstream. -func (wu *funcUpstream) Address() string { +func (wu *funcUpstream) Address() (addr string) { if wu.addressFunc == nil { return "stub" } @@ -775,6 +769,11 @@ func (wu *funcUpstream) Address() string { return wu.addressFunc() } +// Close implements upstream.Upstream interface for *funcUpstream. +func (wu *funcUpstream) Close() (err error) { + return nil +} + func TestProxy_ReplyFromUpstream_badResponse(t *testing.T) { dnsProxy := createTestProxy(t, nil) require.NoError(t, dnsProxy.Start()) @@ -1289,6 +1288,10 @@ type testUpstream struct { ecsReqMask int } +// type check +var _ upstream.Upstream = (*testUpstream)(nil) + +// Exchange implements the upstream.Upstream interface for *testUpstream. func (u *testUpstream) Exchange(m *dns.Msg) (resp *dns.Msg, err error) { resp = &dns.Msg{} resp.SetReply(m) @@ -1309,10 +1312,16 @@ func (u *testUpstream) Exchange(m *dns.Msg) (resp *dns.Msg, err error) { return resp, nil } -func (u *testUpstream) Address() string { +// Address implements the upstream.Upstream interface for *testUpstream. +func (u *testUpstream) Address() (addr string) { return "" } +// Close implements the upstream.Upstream interface for *testUpstream. +func (u *testUpstream) Close() (err error) { + return nil +} + func TestProxy_Resolve_withOptimisticResolver(t *testing.T) { const ( host = "some.domain.name." diff --git a/proxy/upstreams.go b/proxy/upstreams.go index f85255a5e..97c550f09 100644 --- a/proxy/upstreams.go +++ b/proxy/upstreams.go @@ -2,13 +2,14 @@ package proxy import ( "fmt" + "io" "strings" + "github.com/AdguardTeam/dnsproxy/upstream" + "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/stringutil" - - "github.com/AdguardTeam/dnsproxy/upstream" ) // UpstreamConfig is a wrapper for list of default upstreams and map of reserved domains and corresponding upstreams @@ -19,6 +20,9 @@ type UpstreamConfig struct { SubdomainExclusions *stringutil.Set // set of domains with sub-domains exclusions } +// type check +var _ io.Closer = (*UpstreamConfig)(nil) + // ParseUpstreamsConfig returns UpstreamConfig and error if upstreams configuration is invalid // default upstream syntax: // reserved upstream syntax: [/domain1/../domainN/] @@ -159,12 +163,15 @@ func parseUpstreamLine(l string) (string, []string, error) { return u, hosts, nil } -// getUpstreamsForDomain looks for a domain in reserved domains map and returns a list of corresponding upstreams. -// returns default upstreams list if domain isn't found. More specific domains take priority over less specific domains. -// For example, map contains the following keys: host.com and www.host.com -// If we are looking for domain mail.host.com, this method will return value of host.com key -// If we are looking for domain www.host.com, this method will return value of www.host.com key -// If more specific domain value is nil, it means that domain was excluded and should be exchanged with default upstreams +// getUpstreamsForDomain looks for a domain in the reserved domains map and +// returns a list of corresponding upstreams. It returns default upstreams list +// if the domain was not found in the map. More specific domains take priority +// over less specific domains. For example, take a map that contains the +// following keys: host.com and www.host.com. If we are looking for domain +// mail.host.com, this method will return value of host.com key. If we are +// looking for domain www.host.com, this method will return value of the +// www.host.com key. If a more specific domain value is nil, it means that the +// domain was excluded and should be exchanged with default upstreams. func (uc *UpstreamConfig) getUpstreamsForDomain(host string) (ups []upstream.Upstream) { if len(uc.DomainReservedUpstreams) == 0 { return uc.Upstreams @@ -214,3 +221,15 @@ func (uc *UpstreamConfig) getUpstreamsForDomain(host string) (ups []upstream.Ups return uc.Upstreams } + +// Close implements the io.Closer interface for *UpstreamConfig. +func (uc *UpstreamConfig) Close() (err error) { + closeErrs := []error{} + closeAll(uc.Upstreams, &closeErrs) + + if len(closeErrs) > 0 { + return errors.List("failed to close some upstreams", closeErrs...) + } + + return nil +} diff --git a/upstream/parallel_test.go b/upstream/parallel_test.go index 6a7488360..8e8ba9579 100644 --- a/upstream/parallel_test.go +++ b/upstream/parallel_test.go @@ -106,7 +106,11 @@ type testUpstream struct { sleep time.Duration // a delay before response } -func (u *testUpstream) Exchange(req *dns.Msg) (*dns.Msg, error) { +// type check +var _ Upstream = (*testUpstream)(nil) + +// Exchange implements the Upstream interface for *testUpstream. +func (u *testUpstream) Exchange(req *dns.Msg) (resp *dns.Msg, err error) { if u.sleep != 0 { time.Sleep(u.sleep) } @@ -115,7 +119,7 @@ func (u *testUpstream) Exchange(req *dns.Msg) (*dns.Msg, error) { return nil, nil } - resp := &dns.Msg{} + resp = &dns.Msg{} resp.SetReply(req) if len(u.a) != 0 { @@ -131,10 +135,16 @@ func (u *testUpstream) Exchange(req *dns.Msg) (*dns.Msg, error) { return resp, nil } -func (u *testUpstream) Address() string { +// Address implements the Upstream interface for *testUpstream. +func (u *testUpstream) Address() (addr string) { return "" } +// Close implements the Upstream interface for *testUpstream. +func (u *testUpstream) Close() (err error) { + return nil +} + func TestExchangeAll(t *testing.T) { u1 := testUpstream{} u1.a = net.ParseIP("1.1.1.1") diff --git a/upstream/upstream.go b/upstream/upstream.go index 586526d67..38f564763 100644 --- a/upstream/upstream.go +++ b/upstream/upstream.go @@ -6,6 +6,7 @@ import ( "crypto/tls" "crypto/x509" "fmt" + "io" "net" "net/url" "strconv" @@ -26,6 +27,7 @@ type Upstream interface { Exchange(m *dns.Msg) (*dns.Msg, error) // Address returns the address of the upstream DNS resolver. Address() string + io.Closer } // Options for AddressToUpstream func. With these options we can configure the diff --git a/upstream/upstream_dnscrypt.go b/upstream/upstream_dnscrypt.go index d0a0c04d5..233a14f9c 100644 --- a/upstream/upstream_dnscrypt.go +++ b/upstream/upstream_dnscrypt.go @@ -49,6 +49,12 @@ func (p *dnsCrypt) Exchange(m *dns.Msg) (*dns.Msg, error) { return reply, err } +// Close implements the Upstream interface for *dnsCrypt. +func (p *dnsCrypt) Close() (err error) { + // Nothing to close here. + return nil +} + // exchangeDNSCrypt attempts to send the DNS query and returns the response func (p *dnsCrypt) exchangeDNSCrypt(m *dns.Msg) (reply *dns.Msg, err error) { p.RLock() diff --git a/upstream/upstream_dnscrypt_test.go b/upstream/upstream_dnscrypt_test.go index 2e367f3bb..3d1944caf 100644 --- a/upstream/upstream_dnscrypt_test.go +++ b/upstream/upstream_dnscrypt_test.go @@ -4,16 +4,18 @@ import ( "net" "testing" + "github.com/AdguardTeam/golibs/testutil" "github.com/ameshkov/dnscrypt/v2" "github.com/miekg/dns" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestUpstreamDNSCrypt(t *testing.T) { // AdGuard DNS (DNSCrypt) address := "sdns://AQIAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20" u, err := AddressToUpstream(address, &Options{Timeout: dialTimeout}) - assert.Nil(t, err) + require.NoError(t, err) + testutil.CleanupAndRequireSuccess(t, u.Close) // Test that it responds properly for i := 0; i < 10; i++ { @@ -24,10 +26,10 @@ func TestUpstreamDNSCrypt(t *testing.T) { func TestDNSCryptTruncated(t *testing.T) { // Prepare the test DNSCrypt server config rc, err := dnscrypt.GenerateResolverConfig("example.org", nil) - assert.Nil(t, err) + require.NoError(t, err) cert, err := rc.CreateCert() - assert.Nil(t, err) + require.NoError(t, err) s := &dnscrypt.Server{ ProviderName: rc.ProviderName, @@ -37,14 +39,14 @@ func TestDNSCryptTruncated(t *testing.T) { // Prepare TCP listener tcpConn, err := net.ListenTCP("tcp", &net.TCPAddr{IP: net.IPv4zero, Port: 0}) - assert.Nil(t, err) - defer tcpConn.Close() + require.NoError(t, err) + testutil.CleanupAndRequireSuccess(t, tcpConn.Close) // Prepare UDP listener - on the same port port := tcpConn.Addr().(*net.TCPAddr).Port udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: port}) - assert.Nil(t, err) - defer udpConn.Close() + require.NoError(t, err) + testutil.CleanupAndRequireSuccess(t, udpConn.Close) // Start the server go s.ServeUDP(udpConn) @@ -52,9 +54,10 @@ func TestDNSCryptTruncated(t *testing.T) { // Now prepare a client for this test server stamp, err := rc.CreateStamp(udpConn.LocalAddr().String()) - assert.Nil(t, err) + require.NoError(t, err) u, err := AddressToUpstream(stamp.String(), &Options{Timeout: timeout}) - assert.Nil(t, err) + require.NoError(t, err) + testutil.CleanupAndRequireSuccess(t, u.Close) req := new(dns.Msg) req.SetQuestion("unit-test2.dns.adguard.com.", dns.TypeTXT) @@ -62,8 +65,8 @@ func TestDNSCryptTruncated(t *testing.T) { // Check that response is not truncated (even though it's huge) res, err := u.Exchange(req) - assert.Nil(t, err) - assert.False(t, res.Truncated) + require.NoError(t, err) + require.False(t, res.Truncated) } type testDNSCryptHandler struct{} diff --git a/upstream/upstream_doh.go b/upstream/upstream_doh.go index 7c0fe06ca..a7ecf415b 100644 --- a/upstream/upstream_doh.go +++ b/upstream/upstream_doh.go @@ -9,6 +9,7 @@ import ( "net" "net/http" "net/url" + "runtime" "sync" "time" @@ -57,7 +58,7 @@ type dnsOverHTTPS struct { } // type check -var _ Upstream = &dnsOverHTTPS{} +var _ Upstream = (*dnsOverHTTPS)(nil) // newDoH returns the DNS-over-HTTPS Upstream. func newDoH(uu *url.URL, opts *Options) (u Upstream, err error) { @@ -69,14 +70,18 @@ func newDoH(uu *url.URL, opts *Options) (u Upstream, err error) { return nil, fmt.Errorf("creating https bootstrapper: %w", err) } - return &dnsOverHTTPS{ + u = &dnsOverHTTPS{ boot: b, quicConfig: &quic.Config{ KeepAlivePeriod: QUICKeepAlivePeriod, TokenStore: newQUICTokenStore(), }, - }, nil + } + + runtime.SetFinalizer(u, (*dnsOverHTTPS).Close) + + return u, nil } // Address implements the Upstream interface for *dnsOverHTTPS. @@ -128,6 +133,26 @@ func (p *dnsOverHTTPS) Exchange(m *dns.Msg) (resp *dns.Msg, err error) { return resp, err } +// Close implements the Upstream interface for *dnsOverHTTPS. +func (p *dnsOverHTTPS) Close() (err error) { + p.clientGuard.Lock() + defer p.clientGuard.Unlock() + + runtime.SetFinalizer(p, nil) + + if p.client == nil { + return nil + } + + // We should only explicitly close it when the client is for DoH3. Native + // http.Client is stateless and does not require explicit cleanup. + if t, ok := p.client.Transport.(*http3.RoundTripper); ok { + err = t.Close() + } + + return err +} + // exchangeHTTPS creates an HTTP client and sends the DNS query using it. func (p *dnsOverHTTPS) exchangeHTTPS(m *dns.Msg) (resp *dns.Msg, err error) { client, err := p.getClient() diff --git a/upstream/upstream_doh_test.go b/upstream/upstream_doh_test.go index 85b1e18d0..e4498e77b 100644 --- a/upstream/upstream_doh_test.go +++ b/upstream/upstream_doh_test.go @@ -12,6 +12,7 @@ import ( "testing" "time" + "github.com/AdguardTeam/golibs/testutil" "github.com/lucas-clemente/quic-go" "github.com/lucas-clemente/quic-go/http3" "github.com/miekg/dns" @@ -85,6 +86,7 @@ func TestUpstreamDoH(t *testing.T) { } u, err := AddressToUpstream(address, opts) require.NoError(t, err) + testutil.CleanupAndRequireSuccess(t, u.Close) // Test that it responds properly. for i := 0; i < 10; i++ { @@ -179,6 +181,7 @@ func TestUpstreamDoH_raceReconnect(t *testing.T) { } u, err := AddressToUpstream(address, opts) require.NoError(t, err) + testutil.CleanupAndRequireSuccess(t, u.Close) checkRaceCondition(u) }) @@ -218,6 +221,7 @@ func TestUpstreamDoH_serverRestart(t *testing.T) { }, ) require.NoError(t, err) + testutil.CleanupAndRequireSuccess(t, u.Close) // Test that the upstream works properly. checkUpstream(t, u, address) diff --git a/upstream/upstream_dot.go b/upstream/upstream_dot.go index 380478d34..ce004a366 100644 --- a/upstream/upstream_dot.go +++ b/upstream/upstream_dot.go @@ -4,6 +4,7 @@ import ( "fmt" "net" "net/url" + "runtime" "sync" "github.com/AdguardTeam/golibs/errors" @@ -14,14 +15,13 @@ import ( // dnsOverTLS is a struct that implements the Upstream interface for the // DNS-over-TLS protocol. type dnsOverTLS struct { - boot *bootstrapper - pool *TLSPool - - sync.RWMutex // protects pool + boot *bootstrapper + pool *TLSPool + poolMu sync.Mutex } // type check -var _ Upstream = &dnsOverTLS{} +var _ Upstream = (*dnsOverTLS)(nil) // newDoT returns the DNS-over-TLS Upstream. func newDoT(uu *url.URL, opts *Options) (u Upstream, err error) { @@ -33,7 +33,11 @@ func newDoT(uu *url.URL, opts *Options) (u Upstream, err error) { return nil, fmt.Errorf("creating tls bootstrapper: %w", err) } - return &dnsOverTLS{boot: b}, nil + u = &dnsOverTLS{boot: b} + + runtime.SetFinalizer(u, (*dnsOverTLS).Close) + + return u, nil } // Address implements the Upstream interface for *dnsOverTLS. @@ -41,20 +45,9 @@ func (p *dnsOverTLS) Address() string { return p.boot.URL.String() } // Exchange implements the Upstream interface for *dnsOverTLS. func (p *dnsOverTLS) Exchange(m *dns.Msg) (reply *dns.Msg, err error) { - var pool *TLSPool - p.RLock() - pool = p.pool - p.RUnlock() - if pool == nil { - p.Lock() - // lazy initialize it - p.pool = &TLSPool{boot: p.boot} - p.Unlock() - } + pool := p.getPool() - p.RLock() - poolConn, err := p.pool.Get() - p.RUnlock() + poolConn, err := pool.Get() if err != nil { return nil, fmt.Errorf("getting connection to %s: %w", p.Address(), err) } @@ -62,6 +55,7 @@ func (p *dnsOverTLS) Exchange(m *dns.Msg) (reply *dns.Msg, err error) { logBegin(p.Address(), m) reply, err = p.exchangeConn(poolConn, m) logFinish(p.Address(), err) + if err != nil { log.Tracef("The TLS connection is expired due to %s", err) @@ -69,9 +63,7 @@ func (p *dnsOverTLS) Exchange(m *dns.Msg) (reply *dns.Msg, err error) { // So we're trying to re-connect right away here. // We are forcing creation of a new connection instead of calling Get() again // as there's no guarantee that other pooled connections are intact - p.RLock() - poolConn, err = p.pool.Create() - p.RUnlock() + poolConn, err = pool.Create() if err != nil { return nil, fmt.Errorf("creating new connection to %s: %w", p.Address(), err) } @@ -83,13 +75,25 @@ func (p *dnsOverTLS) Exchange(m *dns.Msg) (reply *dns.Msg, err error) { } if err == nil { - p.RLock() - p.pool.Put(poolConn) - p.RUnlock() + pool.Put(poolConn) } return reply, err } +// Close implements the Upstream interface for *dnsOverTLS. +func (p *dnsOverTLS) Close() (err error) { + p.poolMu.Lock() + defer p.poolMu.Unlock() + + runtime.SetFinalizer(p, nil) + + if p.pool == nil { + return nil + } + + return p.pool.Close() +} + func (p *dnsOverTLS) exchangeConn(conn net.Conn, m *dns.Msg) (reply *dns.Msg, err error) { defer func() { if err == nil { @@ -117,3 +121,14 @@ func (p *dnsOverTLS) exchangeConn(conn net.Conn, m *dns.Msg) (reply *dns.Msg, er return reply, err } + +func (p *dnsOverTLS) getPool() (pool *TLSPool) { + p.poolMu.Lock() + defer p.poolMu.Unlock() + + if p.pool == nil { + p.pool = &TLSPool{boot: p.boot} + } + + return p.pool +} diff --git a/upstream/upstream_plain.go b/upstream/upstream_plain.go index f5f4efd67..1855d7789 100644 --- a/upstream/upstream_plain.go +++ b/upstream/upstream_plain.go @@ -68,3 +68,9 @@ func (p *plainDNS) Exchange(m *dns.Msg) (*dns.Msg, error) { return reply, err } + +// Close implements the Upstream interface for *plainDNS. +func (p *plainDNS) Close() (err error) { + // Nothing to close here. + return nil +} diff --git a/upstream/upstream_plain_test.go b/upstream/upstream_plain_test.go index d7abbaa6e..5f17e6a4f 100644 --- a/upstream/upstream_plain_test.go +++ b/upstream/upstream_plain_test.go @@ -3,29 +3,25 @@ package upstream import ( "testing" + "github.com/AdguardTeam/golibs/testutil" "github.com/miekg/dns" + "github.com/stretchr/testify/require" ) +// TODO(ameshkov): make this test not depend on external resources. func TestDNSTruncated(t *testing.T) { // AdGuard DNS address := "94.140.14.14:53" - // Google DNS - // address := "8.8.8.8:53" + u, err := AddressToUpstream(address, &Options{Timeout: timeout}) - if err != nil { - t.Fatalf("error while creating an upstream: %s", err) - } + require.NoError(t, err) + testutil.CleanupAndRequireSuccess(t, u.Close) req := new(dns.Msg) req.SetQuestion("unit-test2.dns.adguard.com.", dns.TypeTXT) req.RecursionDesired = true res, err := u.Exchange(req) - if err != nil { - t.Fatalf("error while making a request: %s", err) - } - - if res.Truncated { - t.Fatalf("response must NOT be truncated") - } + require.NoError(t, err) + require.False(t, res.Truncated) } diff --git a/upstream/upstream_pool.go b/upstream/upstream_pool.go index 9c3ddd3bc..fbd0097ba 100644 --- a/upstream/upstream_pool.go +++ b/upstream/upstream_pool.go @@ -4,10 +4,12 @@ import ( "context" "crypto/tls" "fmt" + "io" "net" "sync" "time" + "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" ) @@ -36,33 +38,38 @@ type TLSPool struct { boot *bootstrapper // conns is the list of connections available in the pool. - conns []net.Conn - // connsMutex protects conns. - connsMutex sync.Mutex + conns []net.Conn + connsMu sync.Mutex } +// type check +var _ io.Closer = (*TLSPool)(nil) + // Get gets a connection from the pool (if there's one available) or creates // a new TLS connection. -func (n *TLSPool) Get() (net.Conn, error) { +func (n *TLSPool) Get() (conn net.Conn, err error) { // Get the connection from the slice inside the lock. - var c net.Conn - n.connsMutex.Lock() + n.connsMu.Lock() num := len(n.conns) if num > 0 { last := num - 1 - c = n.conns[last] + conn = n.conns[last] n.conns = n.conns[:last] } - n.connsMutex.Unlock() + n.connsMu.Unlock() // If we got connection from the slice, update deadline and return it. - if c != nil { - err := c.SetDeadline(time.Now().Add(dialTimeout)) + if conn != nil { + err = conn.SetDeadline(time.Now().Add(dialTimeout)) // If deadLine can't be updated it means that connection was already closed if err == nil { - log.Tracef("Returning existing connection to %s with updated deadLine", c.RemoteAddr()) - return c, nil + log.Tracef( + "Returning existing connection to %s with updated deadLine", + conn.RemoteAddr(), + ) + + return conn, nil } } @@ -70,14 +77,13 @@ func (n *TLSPool) Get() (net.Conn, error) { } // Create creates a new connection for the pool (but not puts it there). -func (n *TLSPool) Create() (net.Conn, error) { +func (n *TLSPool) Create() (conn net.Conn, err error) { tlsConfig, dialContext, err := n.boot.get() if err != nil { return nil, err } - // we'll need a new connection, dial now - conn, err := tlsDial(dialContext, "tcp", tlsConfig) + conn, err = tlsDial(dialContext, "tcp", tlsConfig) if err != nil { return nil, fmt.Errorf("connecting to %s: %w", tlsConfig.ServerName, err) } @@ -86,13 +92,35 @@ func (n *TLSPool) Create() (net.Conn, error) { } // Put returns the connection to the pool. -func (n *TLSPool) Put(c net.Conn) { - if c == nil { +func (n *TLSPool) Put(conn net.Conn) { + if conn == nil { return } - n.connsMutex.Lock() - n.conns = append(n.conns, c) - n.connsMutex.Unlock() + + n.connsMu.Lock() + defer n.connsMu.Unlock() + + n.conns = append(n.conns, conn) +} + +// Close implements io.Closer for *TLSPool. +func (n *TLSPool) Close() (err error) { + n.connsMu.Lock() + defer n.connsMu.Unlock() + + var closeErrs []error + for _, c := range n.conns { + cErr := c.Close() + if cErr != nil { + closeErrs = append(closeErrs, cErr) + } + } + + if len(closeErrs) > 0 { + return errors.List("failed to close some connections", closeErrs...) + } + + return nil } // tlsDial is basically the same as tls.DialWithDialer, but we will call our own @@ -108,17 +136,17 @@ func tlsDial(dialContext dialHandler, network string, config *tls.Config) (*tls. // We want the timeout to cover the whole process: TCP connection and // TLS handshake dialTimeout will be used as connection deadLine. conn := tls.Client(rawConn, config) + err = conn.SetDeadline(time.Now().Add(dialTimeout)) if err != nil { - log.Printf("DeadLine is not supported cause: %s", err) - conn.Close() - return nil, err + // Must not happen in normal circumstances. + panic(fmt.Errorf("cannot set deadline: %w", err)) } err = conn.Handshake() if err != nil { - conn.Close() - return nil, err + return nil, errors.WithDeferred(err, conn.Close()) } + return conn, nil } diff --git a/upstream/upstream_pool_test.go b/upstream/upstream_pool_test.go index 85b73ad3b..4b7389888 100644 --- a/upstream/upstream_pool_test.go +++ b/upstream/upstream_pool_test.go @@ -5,9 +5,11 @@ import ( "testing" "time" + "github.com/AdguardTeam/golibs/testutil" "github.com/stretchr/testify/require" ) +// TODO(ameshkov): make it not depend on external servers. func TestTLSPoolReconnect(t *testing.T) { var lastState tls.ConnectionState u, err := AddressToUpstream( @@ -22,6 +24,7 @@ func TestTLSPoolReconnect(t *testing.T) { }, ) require.NoError(t, err) + testutil.CleanupAndRequireSuccess(t, u.Close) // Send the first test message. req := createTestMessage() @@ -41,15 +44,15 @@ func TestTLSPoolReconnect(t *testing.T) { require.NoError(t, err) requireResponse(t, req, reply) - // Now assert that the number of connections in the pool is not changed + // Now assert that the number of connections in the pool is not changed. require.Len(t, p.pool.conns, 1) // Check that the session was resumed on the last attempt. require.True(t, lastState.DidResume) } +// TODO(ameshkov): make it not depend on external servers. func TestTLSPoolDeadLine(t *testing.T) { - // Create TLS upstream u, err := AddressToUpstream( "tls://one.one.one.one", &Options{ @@ -57,58 +60,43 @@ func TestTLSPoolDeadLine(t *testing.T) { Timeout: timeout, }, ) - if err != nil { - t.Fatalf("cannot create upstream: %s", err) - } + require.NoError(t, err) + testutil.CleanupAndRequireSuccess(t, u.Close) - // Send the first test message + // Send the first test message. req := createTestMessage() response, err := u.Exchange(req) - if err != nil { - t.Fatalf("first DNS message failed: %s", err) - } + require.NoError(t, err) requireResponse(t, req, response) p := u.(*dnsOverTLS) - // Now let's get connection from the pool and use it + // Now let's get connection from the pool and use it. conn, err := p.pool.Get() - if err != nil { - t.Fatalf("couldn't get connection from pool: %s", err) - } + require.NoError(t, err) + response, err = p.exchangeConn(conn, req) - if err != nil { - t.Fatalf("first DNS message failed: %s", err) - } + require.NoError(t, err) requireResponse(t, req, response) - // Update connection's deadLine and put it back to the pool + // Update connection's deadLine and put it back to the pool. err = conn.SetDeadline(time.Now().Add(10 * time.Hour)) - if err != nil { - t.Fatalf("can't set new deadLine for connection. Looks like it's already closed: %s", err) - } + require.NoError(t, err) p.pool.Put(conn) - // Get connection from the pool and reuse it + // Get connection from the pool and reuse it. conn, err = p.pool.Get() - if err != nil { - t.Fatalf("couldn't get connection from pool: %s", err) - } + require.NoError(t, err) + response, err = p.exchangeConn(conn, req) - if err != nil { - t.Fatalf("first DNS message failed: %s", err) - } + require.NoError(t, err) requireResponse(t, req, response) - // Set connection's deadLine to the past and try to reuse it + // Set connection's deadLine to the past and try to reuse it. err = conn.SetDeadline(time.Now().Add(-10 * time.Hour)) - if err != nil { - t.Fatalf("can't set new deadLine for connection. Looks like it's already closed: %s", err) - } + require.NoError(t, err) - // Connection with expired deadLine can't be used + // Connection with expired deadLine can't be used. response, err = p.exchangeConn(conn, req) - if err == nil { - t.Fatalf("this connection should be already closed, got response %s", response) - } + require.Error(t, err) } diff --git a/upstream/upstream_quic.go b/upstream/upstream_quic.go index 5a028bcd3..dafd9fc02 100644 --- a/upstream/upstream_quic.go +++ b/upstream/upstream_quic.go @@ -5,6 +5,7 @@ import ( "fmt" "net" "net/url" + "runtime" "sync" "time" @@ -49,8 +50,8 @@ type dnsOverQUIC struct { // conn is the current active QUIC connection. It can be closed and // re-opened when needed. - conn quic.Connection - connGuard sync.RWMutex + conn quic.Connection + connMu sync.RWMutex // bytesPool is a *sync.Pool we use to store byte buffers in. These byte // buffers are used to read responses from the upstream. @@ -59,7 +60,7 @@ type dnsOverQUIC struct { } // type check -var _ Upstream = &dnsOverQUIC{} +var _ Upstream = (*dnsOverQUIC)(nil) // newDoQ returns the DNS-over-QUIC Upstream. func newDoQ(uu *url.URL, opts *Options) (u Upstream, err error) { @@ -71,13 +72,17 @@ func newDoQ(uu *url.URL, opts *Options) (u Upstream, err error) { return nil, fmt.Errorf("creating quic bootstrapper: %w", err) } - return &dnsOverQUIC{ + u = &dnsOverQUIC{ boot: b, quicConfig: &quic.Config{ KeepAlivePeriod: QUICKeepAlivePeriod, TokenStore: newQUICTokenStore(), }, - }, nil + } + + runtime.SetFinalizer(u, (*dnsOverQUIC).Close) + + return u, nil } // Address implements the Upstream interface for *dnsOverQUIC. @@ -128,6 +133,20 @@ func (p *dnsOverQUIC) Exchange(m *dns.Msg) (resp *dns.Msg, err error) { return resp, err } +// Close implements the Upstream interface for *dnsOverQUIC. +func (p *dnsOverQUIC) Close() (err error) { + p.connMu.Lock() + defer p.connMu.Unlock() + + runtime.SetFinalizer(p, nil) + + if p.conn != nil { + err = p.conn.CloseWithError(QUICCodeNoError, "") + } + + return err +} + // exchangeQUIC attempts to open a QUIC connection, send the DNS message // through it and return the response it got from the server. func (p *dnsOverQUIC) exchangeQUIC(m *dns.Msg) (resp *dns.Msg, err error) { @@ -193,10 +212,10 @@ func (p *dnsOverQUIC) getBytesPool() (pool *sync.Pool) { // close the existing one if needed. func (p *dnsOverQUIC) getConnection(useCached bool) (quic.Connection, error) { var conn quic.Connection - p.connGuard.RLock() + p.connMu.RLock() conn = p.conn if conn != nil && useCached { - p.connGuard.RUnlock() + p.connMu.RUnlock() return conn, nil } @@ -204,10 +223,10 @@ func (p *dnsOverQUIC) getConnection(useCached bool) (quic.Connection, error) { // we're recreating the connection, let's create a new one. _ = conn.CloseWithError(QUICCodeNoError, "") } - p.connGuard.RUnlock() + p.connMu.RUnlock() - p.connGuard.Lock() - defer p.connGuard.Unlock() + p.connMu.Lock() + defer p.connMu.Unlock() var err error conn, err = p.openConnection() @@ -221,8 +240,8 @@ func (p *dnsOverQUIC) getConnection(useCached bool) (quic.Connection, error) { // hasConnection returns true if there's an active QUIC connection. func (p *dnsOverQUIC) hasConnection() (ok bool) { - p.connGuard.Lock() - defer p.connGuard.Unlock() + p.connMu.Lock() + defer p.connMu.Unlock() return p.conn != nil } @@ -305,8 +324,8 @@ func (p *dnsOverQUIC) openConnection() (conn quic.Connection, err error) { // new queries were processed in another connection. We can do that in the case // of a fatal error. func (p *dnsOverQUIC) closeConnWithError(err error) { - p.connGuard.Lock() - defer p.connGuard.Unlock() + p.connMu.Lock() + defer p.connMu.Unlock() if p.conn == nil { // Do nothing, there's no active conn anyways. diff --git a/upstream/upstream_quic_test.go b/upstream/upstream_quic_test.go index bb49df6fd..2bc17b88d 100644 --- a/upstream/upstream_quic_test.go +++ b/upstream/upstream_quic_test.go @@ -13,6 +13,7 @@ import ( "github.com/AdguardTeam/dnsproxy/proxyutil" "github.com/AdguardTeam/golibs/log" + "github.com/AdguardTeam/golibs/testutil" "github.com/lucas-clemente/quic-go" "github.com/miekg/dns" "github.com/stretchr/testify/require" @@ -33,6 +34,7 @@ func TestUpstreamDoQ(t *testing.T) { } u, err := AddressToUpstream(address, opts) require.NoError(t, err) + testutil.CleanupAndRequireSuccess(t, u.Close) uq := u.(*dnsOverQUIC) var conn quic.Connection @@ -62,6 +64,7 @@ func TestUpstreamDoQ(t *testing.T) { // check it for race conditions. u, err = AddressToUpstream(address, opts) require.NoError(t, err) + testutil.CleanupAndRequireSuccess(t, u.Close) checkRaceCondition(u) } @@ -74,6 +77,7 @@ func TestUpstreamDoQ_serverRestart(t *testing.T) { address := fmt.Sprintf("quic://%s", srv.addr) u, err := AddressToUpstream(address, &Options{InsecureSkipVerify: true, Timeout: time.Second}) require.NoError(t, err) + testutil.CleanupAndRequireSuccess(t, u.Close) // Test that the upstream works properly. checkUpstream(t, u, address) diff --git a/upstream/upstream_test.go b/upstream/upstream_test.go index d290c0c21..b33e5e863 100644 --- a/upstream/upstream_test.go +++ b/upstream/upstream_test.go @@ -25,6 +25,8 @@ import ( "github.com/stretchr/testify/require" ) +// TODO(ameshkov): make tests here not depend on external servers. + func TestMain(m *testing.M) { // Disable logging in tests. log.SetOutput(io.Discard) @@ -44,6 +46,7 @@ func TestBootstrapTimeout(t *testing.T) { Timeout: timeout, }) require.NoError(t, err) + testutil.CleanupAndRequireSuccess(t, u.Close) ch := make(chan int, count) abort := make(chan string, 1) @@ -91,9 +94,8 @@ func TestUpstreamRace(t *testing.T) { "tls://1.1.1.1", &Options{Timeout: timeout}, ) - if err != nil { - t.Fatalf("cannot create upstream: %s", err) - } + require.NoError(t, err) + testutil.CleanupAndRequireSuccess(t, u.Close) ch := make(chan int, count) abort := make(chan string, 1) @@ -213,6 +215,7 @@ func TestUpstreams(t *testing.T) { &Options{Bootstrap: test.bootstrap, Timeout: timeout}, ) require.NoErrorf(t, err, "failed to generate upstream from address %s", test.address) + testutil.CleanupAndRequireSuccess(t, u.Close) checkUpstream(t, u, test.address) }) @@ -260,6 +263,7 @@ func TestAddressToUpstream(t *testing.T) { t.Run(tc.addr, func(t *testing.T) { u, err := AddressToUpstream(tc.addr, tc.opt) require.NoError(t, err) + testutil.CleanupAndRequireSuccess(t, u.Close) assert.Equal(t, tc.want, u.Address()) }) @@ -315,6 +319,7 @@ func TestUpstreamDoTBootstrap(t *testing.T) { Timeout: timeout, }) require.NoErrorf(t, err, "failed to generate upstream from address %s", tc.address) + testutil.CleanupAndRequireSuccess(t, u.Close) checkUpstream(t, u, tc.address) }) @@ -327,6 +332,7 @@ func TestUpstreamDefaultOptions(t *testing.T) { for _, address := range addresses { u, err := AddressToUpstream(address, nil) require.NoErrorf(t, err, "failed to generate upstream from address %s", address) + testutil.CleanupAndRequireSuccess(t, u.Close) checkUpstream(t, u, address) } @@ -366,6 +372,7 @@ func TestUpstreamsInvalidBootstrap(t *testing.T) { Timeout: timeout, }) require.NoErrorf(t, err, "failed to generate upstream from address %s", tc.address) + testutil.CleanupAndRequireSuccess(t, u.Close) checkUpstream(t, u, tc.address) }) @@ -415,6 +422,7 @@ func TestUpstreamsWithServerIP(t *testing.T) { if err != nil { t.Fatalf("Failed to generate upstream from address %s: %s", tc.address, err) } + testutil.CleanupAndRequireSuccess(t, u.Close) t.Run(tc.address, func(t *testing.T) { checkUpstream(t, u, tc.address)