Skip to content

Commit

Permalink
upstream: imp code
Browse files Browse the repository at this point in the history
  • Loading branch information
EugeneOne1 committed Jun 9, 2023
1 parent 477f0cd commit aa64e19
Show file tree
Hide file tree
Showing 5 changed files with 19 additions and 18 deletions.
4 changes: 2 additions & 2 deletions upstream/upstream_dnscrypt.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ func (p *dnsCrypt) exchangeDNSCrypt(m *dns.Msg) (resp *dns.Msg, err error) {
q := &m.Question[0]
log.Debug("dnscrypt %s: received truncated, falling back to tcp with %s", p.addr, q)

tcpClient := &dnscrypt.Client{Timeout: p.timeout, Net: string(networkTCP)}
tcpClient := &dnscrypt.Client{Timeout: p.timeout, Net: networkTCP}
resp, err = tcpClient.Exchange(m, resolverInfo)
}
if err == nil && resp != nil && resp.Id != m.Id {
Expand All @@ -125,7 +125,7 @@ func (p *dnsCrypt) resetClient() (client *dnscrypt.Client, ri *dnscrypt.Resolver
addr := p.Address()

// Use UDP for DNSCrypt upstreams by default.
client = &dnscrypt.Client{Timeout: p.timeout, Net: string(networkUDP)}
client = &dnscrypt.Client{Timeout: p.timeout, Net: networkUDP}
ri, err = client.Dial(addr)
if err != nil {
// Trigger client and server info renewal on the next request.
Expand Down
25 changes: 13 additions & 12 deletions upstream/upstream_dnscrypt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,12 @@ import (

// Helpers

// DNSCryptHandlerFunc is a function-based implementation of the
// dnsCryptHandlerFunc is a function-based implementation of the
// [dnscrypt.Handler] interface.
type DNSCryptHandlerFunc func(w dnscrypt.ResponseWriter, r *dns.Msg) (err error)
type dnsCryptHandlerFunc func(w dnscrypt.ResponseWriter, r *dns.Msg) (err error)

// ServeDNS implements the [dnscrypt.Handler] interface for DNSCryptHandlerFunc.
func (f DNSCryptHandlerFunc) ServeDNS(w dnscrypt.ResponseWriter, r *dns.Msg) (err error) {
func (f dnsCryptHandlerFunc) ServeDNS(w dnscrypt.ResponseWriter, r *dns.Msg) (err error) {
return f(w, r)
}

Expand All @@ -48,7 +48,10 @@ func startTestDNSCryptServer(
Handler: h,
}
testutil.CleanupAndRequireSuccess(t, func() (err error) {
return s.Shutdown(context.Background())
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()

return s.Shutdown(ctx)
})

localhost := netutil.IPv4Localhost().AsSlice()
Expand All @@ -66,17 +69,15 @@ func startTestDNSCryptServer(
require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, udpConn.Close)

pt := testutil.PanicT{}

// Start the server.
go func() {
udpErr := s.ServeUDP(udpConn)
require.ErrorIs(pt, udpErr, net.ErrClosed)
require.ErrorIs(testutil.PanicT{}, udpErr, net.ErrClosed)
}()

go func() {
tcpErr := s.ServeTCP(tcpConn)
require.NoError(pt, tcpErr)
require.NoError(testutil.PanicT{}, tcpErr)
}()

stamp, err = rc.CreateStamp(udpConn.LocalAddr().String())
Expand Down Expand Up @@ -109,8 +110,8 @@ func TestDNSCrypt_Exchange_truncated(t *testing.T) {
require.NoError(t, err)

var udpNum, tcpNum atomic.Uint32
h := DNSCryptHandlerFunc(func(w dnscrypt.ResponseWriter, r *dns.Msg) (err error) {
if w.RemoteAddr().Network() == string(networkUDP) {
h := dnsCryptHandlerFunc(func(w dnscrypt.ResponseWriter, r *dns.Msg) (err error) {
if w.RemoteAddr().Network() == networkUDP {
udpNum.Add(1)
} else {
tcpNum.Add(1)
Expand Down Expand Up @@ -156,7 +157,7 @@ func TestDNSCrypt_Exchange_deadline(t *testing.T) {
rc, err := dnscrypt.GenerateResolverConfig("example.org", nil)
require.NoError(t, err)

h := DNSCryptHandlerFunc(func(w dnscrypt.ResponseWriter, r *dns.Msg) (err error) {
h := dnsCryptHandlerFunc(func(w dnscrypt.ResponseWriter, r *dns.Msg) (err error) {
return nil
})

Expand All @@ -180,7 +181,7 @@ func TestDNSCrypt_Exchange_dialFail(t *testing.T) {
rc, err := dnscrypt.GenerateResolverConfig("example.org", nil)
require.NoError(t, err)

h := DNSCryptHandlerFunc(func(w dnscrypt.ResponseWriter, r *dns.Msg) (err error) {
h := dnsCryptHandlerFunc(func(w dnscrypt.ResponseWriter, r *dns.Msg) (err error) {
return nil
})

Expand Down
2 changes: 1 addition & 1 deletion upstream/upstream_dot.go
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ func (p *dnsOverTLS) exchangeWithConn(conn net.Conn, m *dns.Msg) (reply *dns.Msg
func tlsDial(dialContext bootstrap.DialHandler, conf *tls.Config) (c *tls.Conn, err error) {
// We're using bootstrapped address instead of what's passed to the
// function.
rawConn, err := dialContext(context.Background(), string(networkTCP), "")
rawConn, err := dialContext(context.Background(), networkTCP, "")
if err != nil {
return nil, err
}
Expand Down
4 changes: 2 additions & 2 deletions upstream/upstream_plain.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ import (

// network is the type of the network. It's either [networkUDP] or
// [networkTCP].
type network string
type network = string

const (
// networkUDP is the UDP network.
Expand Down Expand Up @@ -50,7 +50,7 @@ var _ Upstream = &plainDNS{}
// or "tcp".
func newPlain(addr *url.URL, opts *Options) (u *plainDNS, err error) {
switch addr.Scheme {
case string(networkUDP), string(networkTCP):
case networkUDP, networkTCP:
// Go on.
default:
return nil, fmt.Errorf("unsupported url scheme: %s", addr.Scheme)
Expand Down
2 changes: 1 addition & 1 deletion upstream/upstream_plain_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ func TestUpstream_plainDNS_fallbackToTCP(t *testing.T) {
var udpReqNum, tcpReqNum atomic.Uint32
srv := startDNSServer(t, func(w dns.ResponseWriter, _ *dns.Msg) {
var resp *dns.Msg
if w.RemoteAddr().Network() == string(networkUDP) {
if w.RemoteAddr().Network() == networkUDP {
udpReqNum.Add(1)
resp = tc.udpResp
} else {
Expand Down

0 comments on commit aa64e19

Please sign in to comment.