From aa64e193817cb92ce478800966a2234ba1ddefee Mon Sep 17 00:00:00 2001 From: Eugene Burkov Date: Fri, 9 Jun 2023 15:55:51 +0300 Subject: [PATCH] upstream: imp code --- upstream/upstream_dnscrypt.go | 4 ++-- upstream/upstream_dnscrypt_test.go | 25 +++++++++++++------------ upstream/upstream_dot.go | 2 +- upstream/upstream_plain.go | 4 ++-- upstream/upstream_plain_test.go | 2 +- 5 files changed, 19 insertions(+), 18 deletions(-) diff --git a/upstream/upstream_dnscrypt.go b/upstream/upstream_dnscrypt.go index 1517845b3..f78b6dd77 100644 --- a/upstream/upstream_dnscrypt.go +++ b/upstream/upstream_dnscrypt.go @@ -109,7 +109,7 @@ func (p *dnsCrypt) exchangeDNSCrypt(m *dns.Msg) (resp *dns.Msg, err error) { q := &m.Question[0] log.Debug("dnscrypt %s: received truncated, falling back to tcp with %s", p.addr, q) - tcpClient := &dnscrypt.Client{Timeout: p.timeout, Net: string(networkTCP)} + tcpClient := &dnscrypt.Client{Timeout: p.timeout, Net: networkTCP} resp, err = tcpClient.Exchange(m, resolverInfo) } if err == nil && resp != nil && resp.Id != m.Id { @@ -125,7 +125,7 @@ func (p *dnsCrypt) resetClient() (client *dnscrypt.Client, ri *dnscrypt.Resolver addr := p.Address() // Use UDP for DNSCrypt upstreams by default. - client = &dnscrypt.Client{Timeout: p.timeout, Net: string(networkUDP)} + client = &dnscrypt.Client{Timeout: p.timeout, Net: networkUDP} ri, err = client.Dial(addr) if err != nil { // Trigger client and server info renewal on the next request. diff --git a/upstream/upstream_dnscrypt_test.go b/upstream/upstream_dnscrypt_test.go index 73a5a2cc9..0a12237d9 100644 --- a/upstream/upstream_dnscrypt_test.go +++ b/upstream/upstream_dnscrypt_test.go @@ -21,12 +21,12 @@ import ( // Helpers -// DNSCryptHandlerFunc is a function-based implementation of the +// dnsCryptHandlerFunc is a function-based implementation of the // [dnscrypt.Handler] interface. -type DNSCryptHandlerFunc func(w dnscrypt.ResponseWriter, r *dns.Msg) (err error) +type dnsCryptHandlerFunc func(w dnscrypt.ResponseWriter, r *dns.Msg) (err error) // ServeDNS implements the [dnscrypt.Handler] interface for DNSCryptHandlerFunc. -func (f DNSCryptHandlerFunc) ServeDNS(w dnscrypt.ResponseWriter, r *dns.Msg) (err error) { +func (f dnsCryptHandlerFunc) ServeDNS(w dnscrypt.ResponseWriter, r *dns.Msg) (err error) { return f(w, r) } @@ -48,7 +48,10 @@ func startTestDNSCryptServer( Handler: h, } testutil.CleanupAndRequireSuccess(t, func() (err error) { - return s.Shutdown(context.Background()) + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + return s.Shutdown(ctx) }) localhost := netutil.IPv4Localhost().AsSlice() @@ -66,17 +69,15 @@ func startTestDNSCryptServer( require.NoError(t, err) testutil.CleanupAndRequireSuccess(t, udpConn.Close) - pt := testutil.PanicT{} - // Start the server. go func() { udpErr := s.ServeUDP(udpConn) - require.ErrorIs(pt, udpErr, net.ErrClosed) + require.ErrorIs(testutil.PanicT{}, udpErr, net.ErrClosed) }() go func() { tcpErr := s.ServeTCP(tcpConn) - require.NoError(pt, tcpErr) + require.NoError(testutil.PanicT{}, tcpErr) }() stamp, err = rc.CreateStamp(udpConn.LocalAddr().String()) @@ -109,8 +110,8 @@ func TestDNSCrypt_Exchange_truncated(t *testing.T) { require.NoError(t, err) var udpNum, tcpNum atomic.Uint32 - h := DNSCryptHandlerFunc(func(w dnscrypt.ResponseWriter, r *dns.Msg) (err error) { - if w.RemoteAddr().Network() == string(networkUDP) { + h := dnsCryptHandlerFunc(func(w dnscrypt.ResponseWriter, r *dns.Msg) (err error) { + if w.RemoteAddr().Network() == networkUDP { udpNum.Add(1) } else { tcpNum.Add(1) @@ -156,7 +157,7 @@ func TestDNSCrypt_Exchange_deadline(t *testing.T) { rc, err := dnscrypt.GenerateResolverConfig("example.org", nil) require.NoError(t, err) - h := DNSCryptHandlerFunc(func(w dnscrypt.ResponseWriter, r *dns.Msg) (err error) { + h := dnsCryptHandlerFunc(func(w dnscrypt.ResponseWriter, r *dns.Msg) (err error) { return nil }) @@ -180,7 +181,7 @@ func TestDNSCrypt_Exchange_dialFail(t *testing.T) { rc, err := dnscrypt.GenerateResolverConfig("example.org", nil) require.NoError(t, err) - h := DNSCryptHandlerFunc(func(w dnscrypt.ResponseWriter, r *dns.Msg) (err error) { + h := dnsCryptHandlerFunc(func(w dnscrypt.ResponseWriter, r *dns.Msg) (err error) { return nil }) diff --git a/upstream/upstream_dot.go b/upstream/upstream_dot.go index bd30bebe7..6a0cfde1c 100644 --- a/upstream/upstream_dot.go +++ b/upstream/upstream_dot.go @@ -223,7 +223,7 @@ func (p *dnsOverTLS) exchangeWithConn(conn net.Conn, m *dns.Msg) (reply *dns.Msg func tlsDial(dialContext bootstrap.DialHandler, conf *tls.Config) (c *tls.Conn, err error) { // We're using bootstrapped address instead of what's passed to the // function. - rawConn, err := dialContext(context.Background(), string(networkTCP), "") + rawConn, err := dialContext(context.Background(), networkTCP, "") if err != nil { return nil, err } diff --git a/upstream/upstream_plain.go b/upstream/upstream_plain.go index 9a65c0727..f8d3be237 100644 --- a/upstream/upstream_plain.go +++ b/upstream/upstream_plain.go @@ -17,7 +17,7 @@ import ( // network is the type of the network. It's either [networkUDP] or // [networkTCP]. -type network string +type network = string const ( // networkUDP is the UDP network. @@ -50,7 +50,7 @@ var _ Upstream = &plainDNS{} // or "tcp". func newPlain(addr *url.URL, opts *Options) (u *plainDNS, err error) { switch addr.Scheme { - case string(networkUDP), string(networkTCP): + case networkUDP, networkTCP: // Go on. default: return nil, fmt.Errorf("unsupported url scheme: %s", addr.Scheme) diff --git a/upstream/upstream_plain_test.go b/upstream/upstream_plain_test.go index 44f947fd9..55a8216f2 100644 --- a/upstream/upstream_plain_test.go +++ b/upstream/upstream_plain_test.go @@ -108,7 +108,7 @@ func TestUpstream_plainDNS_fallbackToTCP(t *testing.T) { var udpReqNum, tcpReqNum atomic.Uint32 srv := startDNSServer(t, func(w dns.ResponseWriter, _ *dns.Msg) { var resp *dns.Msg - if w.RemoteAddr().Network() == string(networkUDP) { + if w.RemoteAddr().Network() == networkUDP { udpReqNum.Add(1) resp = tc.udpResp } else {