Skip to content

Commit

Permalink
Pull request: Upstream now implements io.Closer.
Browse files Browse the repository at this point in the history
Merge in DNS/dnsproxy from upstream_closer to master

Squashed commit of the following:

commit 3ac92bc
Author: Andrey Meshkov <am@adguard.com>
Date:   Mon Oct 17 23:53:22 2022 +0300

    fix formatting

commit 3c749a8
Author: Andrey Meshkov <am@adguard.com>
Date:   Mon Oct 17 23:48:04 2022 +0300

    Upstream now implements io.Closer.

    This is rather important because some of the Upstream implementations
    actually require explicit cleanup.  However, there's a lot of old code
    that is not aware of the fact that Upstream can be cleaned up. In order
    to make the life easier for the authors, I used runtime.SetFinalizer
    where possible to guarantee cleanup.
  • Loading branch information
ameshkov committed Oct 18, 2022
1 parent 571baaf commit 95ef855
Show file tree
Hide file tree
Showing 18 changed files with 351 additions and 195 deletions.
14 changes: 12 additions & 2 deletions fastip/fastest_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,14 +81,18 @@ type errUpstream struct {
err error
}

func (u errUpstream) Exchange(m *dns.Msg) (*dns.Msg, error) {
func (u errUpstream) Exchange(_ *dns.Msg) (*dns.Msg, error) {
return nil, u.err
}

type testAUpstream struct {
recs []*dns.A
}

// type check
var _ upstream.Upstream = (*testAUpstream)(nil)

// Exchange implements the upstream.Upstream interface for *testAUpstream.
func (u *testAUpstream) Exchange(m *dns.Msg) (resp *dns.Msg, err error) {
resp = &dns.Msg{}
resp.SetReply(m)
Expand All @@ -100,10 +104,16 @@ func (u *testAUpstream) Exchange(m *dns.Msg) (resp *dns.Msg, err error) {
return resp, nil
}

func (u *testAUpstream) Address() string {
// Address implements the upstream.Upstream interface for *testAUpstream.
func (u *testAUpstream) Address() (addr string) {
return ""
}

// Close implements the upstream.Upstream interface for *testAUpstream.
func (u *testAUpstream) Close() (err error) {
return nil
}

func (u *testAUpstream) add(host string, ip net.IP) (chain *testAUpstream) {
u.recs = append(u.recs, &dns.A{
Hdr: dns.RR_Header{
Expand Down
4 changes: 4 additions & 0 deletions proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,10 @@ func (p *Proxy) Stop() error {
closeAll(p.dnsCryptTCPListen, &errs)
p.dnsCryptTCPListen = nil

if p.UpstreamConfig != nil {
closeAll([]io.Closer{p.UpstreamConfig}, &errs)
}

p.started = false
log.Println("Stopped the DNS proxy server")
if len(errs) > 0 {
Expand Down
115 changes: 62 additions & 53 deletions proxy/proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,10 @@ type testDNSSECUpstream struct {
rrsig dns.RR
}

// type check
var _ upstream.Upstream = (*testDNSSECUpstream)(nil)

// Exchange implements the upstream.Upstream interface for *testDNSSECUpstream.
func (u *testDNSSECUpstream) Exchange(m *dns.Msg) (resp *dns.Msg, err error) {
resp = &dns.Msg{}
resp.SetReply(m)
Expand Down Expand Up @@ -113,10 +117,16 @@ func (u *testDNSSECUpstream) Exchange(m *dns.Msg) (resp *dns.Msg, err error) {
return resp, nil
}

// Address implements the upstream.Upstream interface for *testDNSSECUpstream.
func (u *testDNSSECUpstream) Address() string {
return ""
}

// Close implements the upstream.Upstream interface for *testDNSSECUpstream.
func (u *testDNSSECUpstream) Close() (err error) {
return nil
}

func TestProxy_Resolve_dnssecCache(t *testing.T) {
const host = "example.com"

Expand Down Expand Up @@ -347,90 +357,71 @@ func TestUpstreamsSort(t *testing.T) {
func TestExchangeWithReservedDomains(t *testing.T) {
dnsProxy := createTestProxy(t, nil)

// upstreams specification. Domains adguard.com and google.ru reserved with fake upstreams, maps.google.ru excluded from dnsmasq.
upstreams := []string{"[/adguard.com/]1.2.3.4", "[/google.ru/]2.3.4.5", "[/maps.google.ru/]#", "1.1.1.1"}
// Upstreams specification. Domains adguard.com and google.ru reserved
// with fake upstreams, maps.google.ru excluded from dnsmasq.
upstreams := []string{
"[/adguard.com/]1.2.3.4",
"[/google.ru/]2.3.4.5",
"[/maps.google.ru/]#",
"1.1.1.1",
}
config, err := ParseUpstreamsConfig(
upstreams,
&upstream.Options{
InsecureSkipVerify: false,
Bootstrap: []string{"8.8.8.8"},
Timeout: 1 * time.Second,
})
if err != nil {
t.Fatalf("Error while upstream config parsing: %s", err)
}
},
)
require.NoError(t, err)

dnsProxy.UpstreamConfig = config

err = dnsProxy.Start()
if err != nil {
t.Fatalf("cannot start the DNS proxy: %s", err)
}
require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, dnsProxy.Stop)

// create a DNS-over-TCP client connection
// Create a DNS-over-TCP client connection.
addr := dnsProxy.Addr(ProtoTCP)
conn, err := dns.Dial("tcp", addr.String())
if err != nil {
t.Fatalf("cannot connect to the proxy: %s", err)
}
require.NoError(t, err)

// create google-a test message
// Create google-a test message.
req := createTestMessage()
err = conn.WriteMsg(req)
if err != nil {
t.Fatalf("cannot write message: %s", err)
}
require.NoError(t, err)

// make sure if dnsproxy is working
// Make sure that dnsproxy is working.
res, err := conn.ReadMsg()
if err != nil {
t.Fatalf("cannot read response to message: %s", err)
}
require.NoError(t, err)
requireResponse(t, req, res)

// create adguard.com test message
// Create adguard.com test message.
req = createHostTestMessage("adguard.com")
err = conn.WriteMsg(req)
if err != nil {
t.Fatalf("cannot write message: %s", err)
}
require.NoError(t, err)

// test message should not be resolved
// Test message should not be resolved.
res, _ = conn.ReadMsg()
if res.Answer != nil {
t.Fatal("adguard.com should not be resolved")
}
require.Nil(t, res.Answer)

// create www.google.ru test message
// Create www.google.ru test message.
req = createHostTestMessage("www.google.ru")
err = conn.WriteMsg(req)
if err != nil {
t.Fatalf("cannot write message: %s", err)
}
require.NoError(t, err)

// test message should not be resolved
// Test message should not be resolved.
res, _ = conn.ReadMsg()
if res.Answer != nil {
t.Fatal("www.google.ru should not be resolved")
}
require.Nil(t, res.Answer)

// create maps.google.ru test message
// Create maps.google.ru test message.
req = createHostTestMessage("maps.google.ru")
err = conn.WriteMsg(req)
if err != nil {
t.Fatalf("cannot write message: %s", err)
}
require.NoError(t, err)

// test message should be resolved
// Test message should be resolved.
res, _ = conn.ReadMsg()
if res.Answer == nil {
t.Fatal("maps.google.ru should be resolved")
}

// Stop the proxy
err = dnsProxy.Stop()
if err != nil {
t.Fatalf("cannot stop the DNS proxy: %s", err)
}
require.NotNil(t, res.Answer)
}

// TestOneByOneUpstreamsExchange tries to resolve DNS request
Expand Down Expand Up @@ -757,6 +748,9 @@ type funcUpstream struct {
addressFunc func() (addr string)
}

// type check
var _ upstream.Upstream = (*funcUpstream)(nil)

// Exchange implements upstream.Upstream interface for *funcUpstream.
func (wu *funcUpstream) Exchange(m *dns.Msg) (*dns.Msg, error) {
if wu.exchangeFunc == nil {
Expand All @@ -767,14 +761,19 @@ func (wu *funcUpstream) Exchange(m *dns.Msg) (*dns.Msg, error) {
}

// Address implements upstream.Upstream interface for *funcUpstream.
func (wu *funcUpstream) Address() string {
func (wu *funcUpstream) Address() (addr string) {
if wu.addressFunc == nil {
return "stub"
}

return wu.addressFunc()
}

// Close implements upstream.Upstream interface for *funcUpstream.
func (wu *funcUpstream) Close() (err error) {
return nil
}

func TestProxy_ReplyFromUpstream_badResponse(t *testing.T) {
dnsProxy := createTestProxy(t, nil)
require.NoError(t, dnsProxy.Start())
Expand Down Expand Up @@ -1289,6 +1288,10 @@ type testUpstream struct {
ecsReqMask int
}

// type check
var _ upstream.Upstream = (*testUpstream)(nil)

// Exchange implements the upstream.Upstream interface for *testUpstream.
func (u *testUpstream) Exchange(m *dns.Msg) (resp *dns.Msg, err error) {
resp = &dns.Msg{}
resp.SetReply(m)
Expand All @@ -1309,10 +1312,16 @@ func (u *testUpstream) Exchange(m *dns.Msg) (resp *dns.Msg, err error) {
return resp, nil
}

func (u *testUpstream) Address() string {
// Address implements the upstream.Upstream interface for *testUpstream.
func (u *testUpstream) Address() (addr string) {
return ""
}

// Close implements the upstream.Upstream interface for *testUpstream.
func (u *testUpstream) Close() (err error) {
return nil
}

func TestProxy_Resolve_withOptimisticResolver(t *testing.T) {
const (
host = "some.domain.name."
Expand Down
35 changes: 27 additions & 8 deletions proxy/upstreams.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@ package proxy

import (
"fmt"
"io"
"strings"

"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/stringutil"

"github.com/AdguardTeam/dnsproxy/upstream"
)

// UpstreamConfig is a wrapper for list of default upstreams and map of reserved domains and corresponding upstreams
Expand All @@ -19,6 +20,9 @@ type UpstreamConfig struct {
SubdomainExclusions *stringutil.Set // set of domains with sub-domains exclusions
}

// type check
var _ io.Closer = (*UpstreamConfig)(nil)

// ParseUpstreamsConfig returns UpstreamConfig and error if upstreams configuration is invalid
// default upstream syntax: <upstreamString>
// reserved upstream syntax: [/domain1/../domainN/]<upstreamString>
Expand Down Expand Up @@ -159,12 +163,15 @@ func parseUpstreamLine(l string) (string, []string, error) {
return u, hosts, nil
}

// getUpstreamsForDomain looks for a domain in reserved domains map and returns a list of corresponding upstreams.
// returns default upstreams list if domain isn't found. More specific domains take priority over less specific domains.
// For example, map contains the following keys: host.com and www.host.com
// If we are looking for domain mail.host.com, this method will return value of host.com key
// If we are looking for domain www.host.com, this method will return value of www.host.com key
// If more specific domain value is nil, it means that domain was excluded and should be exchanged with default upstreams
// getUpstreamsForDomain looks for a domain in the reserved domains map and
// returns a list of corresponding upstreams. It returns default upstreams list
// if the domain was not found in the map. More specific domains take priority
// over less specific domains. For example, take a map that contains the
// following keys: host.com and www.host.com. If we are looking for domain
// mail.host.com, this method will return value of host.com key. If we are
// looking for domain www.host.com, this method will return value of the
// www.host.com key. If a more specific domain value is nil, it means that the
// domain was excluded and should be exchanged with default upstreams.
func (uc *UpstreamConfig) getUpstreamsForDomain(host string) (ups []upstream.Upstream) {
if len(uc.DomainReservedUpstreams) == 0 {
return uc.Upstreams
Expand Down Expand Up @@ -214,3 +221,15 @@ func (uc *UpstreamConfig) getUpstreamsForDomain(host string) (ups []upstream.Ups

return uc.Upstreams
}

// Close implements the io.Closer interface for *UpstreamConfig.
func (uc *UpstreamConfig) Close() (err error) {
closeErrs := []error{}
closeAll(uc.Upstreams, &closeErrs)

if len(closeErrs) > 0 {
return errors.List("failed to close some upstreams", closeErrs...)
}

return nil
}
16 changes: 13 additions & 3 deletions upstream/parallel_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,11 @@ type testUpstream struct {
sleep time.Duration // a delay before response
}

func (u *testUpstream) Exchange(req *dns.Msg) (*dns.Msg, error) {
// type check
var _ Upstream = (*testUpstream)(nil)

// Exchange implements the Upstream interface for *testUpstream.
func (u *testUpstream) Exchange(req *dns.Msg) (resp *dns.Msg, err error) {
if u.sleep != 0 {
time.Sleep(u.sleep)
}
Expand All @@ -115,7 +119,7 @@ func (u *testUpstream) Exchange(req *dns.Msg) (*dns.Msg, error) {
return nil, nil
}

resp := &dns.Msg{}
resp = &dns.Msg{}
resp.SetReply(req)

if len(u.a) != 0 {
Expand All @@ -131,10 +135,16 @@ func (u *testUpstream) Exchange(req *dns.Msg) (*dns.Msg, error) {
return resp, nil
}

func (u *testUpstream) Address() string {
// Address implements the Upstream interface for *testUpstream.
func (u *testUpstream) Address() (addr string) {
return ""
}

// Close implements the Upstream interface for *testUpstream.
func (u *testUpstream) Close() (err error) {
return nil
}

func TestExchangeAll(t *testing.T) {
u1 := testUpstream{}
u1.a = net.ParseIP("1.1.1.1")
Expand Down
2 changes: 2 additions & 0 deletions upstream/upstream.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"crypto/tls"
"crypto/x509"
"fmt"
"io"
"net"
"net/url"
"strconv"
Expand All @@ -26,6 +27,7 @@ type Upstream interface {
Exchange(m *dns.Msg) (*dns.Msg, error)
// Address returns the address of the upstream DNS resolver.
Address() string
io.Closer
}

// Options for AddressToUpstream func. With these options we can configure the
Expand Down
6 changes: 6 additions & 0 deletions upstream/upstream_dnscrypt.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,12 @@ func (p *dnsCrypt) Exchange(m *dns.Msg) (*dns.Msg, error) {
return reply, err
}

// Close implements the Upstream interface for *dnsCrypt.
func (p *dnsCrypt) Close() (err error) {
// Nothing to close here.
return nil
}

// exchangeDNSCrypt attempts to send the DNS query and returns the response
func (p *dnsCrypt) exchangeDNSCrypt(m *dns.Msg) (reply *dns.Msg, err error) {
p.RLock()
Expand Down
Loading

0 comments on commit 95ef855

Please sign in to comment.