Skip to content

Commit

Permalink
feat: reuse connection by default (#59)
Browse files Browse the repository at this point in the history
  • Loading branch information
natesales committed Aug 16, 2023
1 parent e74294a commit 0f1ce1e
Show file tree
Hide file tree
Showing 11 changed files with 156 additions and 70 deletions.
13 changes: 5 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,10 @@ Application Options:
--pad Set EDNS0 padding
--http3 Use HTTP/3 for DoH
--no-id-check Disable checking of DNS response ID
--no-reuse-conn Use a new connection for each query
--recaxfr Perform recursive AXFR
-f, --format= Output format (pretty, json, yaml, raw)
(default: pretty)
--pretty-ttls Format TTLs in human readable format (default:
true)
-f, --format= Output format (pretty, json, yaml, raw) (default: pretty)
--pretty-ttls Format TTLs in human readable format (default: true)
--color Enable color output
--question Show question section
--answer Show answer section (default: true)
Expand All @@ -64,8 +63,7 @@ Application Options:
--aa Set AA (Authoritative Answer) flag in query
--ad Set AD (Authentic Data) flag in query
--cd Set CD (Checking Disabled) flag in query
--rd Set RD (Recursion Desired) flag in query
(default: true)
--rd Set RD (Recursion Desired) flag in query (default: true)
--ra Set RA (Recursion Available) flag in query
--z Set Z (Zero) flag in query
--t Set TC (Truncated) flag in query
Expand All @@ -80,8 +78,7 @@ Application Options:
--quic-alpn-tokens= QUIC ALPN tokens (default: doq, doq-i11)
--quic-no-pmtud Disable QUIC PMTU discovery
--quic-no-length-prefix Don't add RFC 9250 compliant length prefix
--default-rr-types= Default record types (default: A, AAAA, NS, MX,
TXT, CNAME)
--default-rr-types= Default record types (default: A, AAAA, NS, MX, TXT, CNAME)
--udp-buffer= Set EDNS0 UDP size in query (default: 1232)
-v, --verbose Show verbose log messages
--trace Show trace log messages
Expand Down
9 changes: 8 additions & 1 deletion main.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ type optsTemplate struct {
Pad bool `long:"pad" description:"Set EDNS0 padding"`
HTTP3 bool `long:"http3" description:"Use HTTP/3 for DoH"`
NoIDCheck bool `long:"no-id-check" description:"Disable checking of DNS response ID"`
NoReuseConn bool `long:"no-reuse-conn" description:"Use a new connection for each query"`

RecAXFR bool `long:"recaxfr" description:"Perform recursive AXFR"`

Expand Down Expand Up @@ -539,10 +540,16 @@ All long form (--) flags can be toggled with the dig-standard +[no]flag notation
return nil
}

// Create transport
txp, err := newTransport(server, protocol, tlsConfig)
if err != nil {
log.Fatalf("creating transport: %s", err)
}

startTime := time.Now()
var replies []*dns.Msg
for _, msg := range msgs {
reply, err := query(msg, server, protocol, tlsConfig)
reply, err := (*txp).Exchange(&msg)
if err != nil {
return err
}
Expand Down
10 changes: 7 additions & 3 deletions resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,8 @@ func createQuery(
return queries
}

// query performs a DNS query and returns the reply
func query(msg dns.Msg, server, protocol string, tlsConfig *tls.Config) (*dns.Msg, error) {
// newTransport creates a new transport based on local options
func newTransport(server, protocol string, tlsConfig *tls.Config) (*transport.Transport, error) {
var ts transport.Transport

switch protocol {
Expand All @@ -120,6 +120,7 @@ func query(msg dns.Msg, server, protocol string, tlsConfig *tls.Config) (*dns.Ms
Target: server,
Proxy: opts.ODoHProxy,
TLSConfig: tlsConfig,
ReuseConn: !opts.NoReuseConn,
}
} else {
log.Debug("Using HTTP(s) transport")
Expand All @@ -131,6 +132,7 @@ func query(msg dns.Msg, server, protocol string, tlsConfig *tls.Config) (*dns.Ms
Timeout: opts.Timeout,
HTTP3: opts.HTTP3,
NoPMTUd: opts.QUICNoPMTUD,
ReuseConn: !opts.NoReuseConn,
}
}
case "quic":
Expand All @@ -140,13 +142,15 @@ func query(msg dns.Msg, server, protocol string, tlsConfig *tls.Config) (*dns.Ms
TLSConfig: tlsConfig,
NoPMTUD: opts.QUICNoPMTUD,
AddLengthPrefix: !opts.QUICNoLengthPrefix,
ReuseConn: !opts.NoReuseConn,
}
case "tls":
log.Debug("Using TLS transport")
ts = &transport.TLS{
Server: server,
TLSConfig: tlsConfig,
Timeout: opts.Timeout,
ReuseConn: !opts.NoReuseConn,
}
case "tcp":
log.Debug("Using TCP transport")
Expand All @@ -168,5 +172,5 @@ func query(msg dns.Msg, server, protocol string, tlsConfig *tls.Config) (*dns.Ms
return nil, fmt.Errorf("unknown transport protocol %s", protocol)
}

return ts.(transport.Transport).Exchange(&msg)
return &ts, nil
}
42 changes: 26 additions & 16 deletions transport/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,26 +23,31 @@ type HTTP struct {
Timeout time.Duration
HTTP3 bool
NoPMTUd bool
ReuseConn bool

conn *http.Client
}

func (h *HTTP) Exchange(m *dns.Msg) (*dns.Msg, error) {
httpClient := &http.Client{
Timeout: h.Timeout,
Transport: &http.Transport{
TLSClientConfig: h.TLSConfig,
MaxConnsPerHost: 1,
MaxIdleConns: 1,
Proxy: http.ProxyFromEnvironment,
},
}
if h.HTTP3 {
log.Debug("Using HTTP/3")
httpClient.Transport = &http3.RoundTripper{
TLSClientConfig: h.TLSConfig,
QuicConfig: &quic.Config{
DisablePathMTUDiscovery: h.NoPMTUd,
if h.conn == nil || !h.ReuseConn {
h.conn = &http.Client{
Timeout: h.Timeout,
Transport: &http.Transport{
TLSClientConfig: h.TLSConfig,
MaxConnsPerHost: 1,
MaxIdleConns: 1,
Proxy: http.ProxyFromEnvironment,
},
}
if h.HTTP3 {
log.Debug("Using HTTP/3")
h.conn.Transport = &http3.RoundTripper{
TLSClientConfig: h.TLSConfig,
QuicConfig: &quic.Config{
DisablePathMTUDiscovery: h.NoPMTUd,
},
}
}
}

buf, err := m.Pack()
Expand All @@ -63,7 +68,7 @@ func (h *HTTP) Exchange(m *dns.Msg) (*dns.Msg, error) {
}

log.Debugf("[http] sending %s request to %s", h.Method, queryURL)
resp, err := httpClient.Do(req)
resp, err := h.conn.Do(req)
if resp != nil && resp.Body != nil {
defer resp.Body.Close()
}
Expand All @@ -87,3 +92,8 @@ func (h *HTTP) Exchange(m *dns.Msg) (*dns.Msg, error) {

return &response, nil
}

func (h *HTTP) Close() error {
h.conn.CloseIdleConnections()
return nil
}
22 changes: 16 additions & 6 deletions transport/odoh.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@ type ODoH struct {
Target string
Proxy string
TLSConfig *tls.Config
ReuseConn bool

conn *http.Client
}

func (o *ODoH) Exchange(m *dns.Msg) (*dns.Msg, error) {
Expand All @@ -77,12 +80,14 @@ func (o *ODoH) Exchange(m *dns.Msg) (*dns.Msg, error) {
return nil, fmt.Errorf("new target configs request: %s", err)
}

client := http.Client{
Transport: &http.Transport{
TLSClientConfig: o.TLSConfig,
},
if o.conn == nil || !o.ReuseConn {
o.conn = &http.Client{
Transport: &http.Transport{
TLSClientConfig: o.TLSConfig,
},
}
}
resp, err := client.Do(req)
resp, err := o.conn.Do(req)
if err != nil {
return nil, fmt.Errorf("do target configs request: %s", err)
}
Expand Down Expand Up @@ -132,7 +137,7 @@ func (o *ODoH) Exchange(m *dns.Msg) (*dns.Msg, error) {
req.Header.Set("Content-Type", ODoHContentType)
req.Header.Set("Accept", ODoHContentType)

resp, err = client.Do(req)
resp, err = o.conn.Do(req)
if err != nil {
return nil, fmt.Errorf("do request: %s", err)
}
Expand Down Expand Up @@ -162,3 +167,8 @@ func (o *ODoH) Exchange(m *dns.Msg) (*dns.Msg, error) {
}
return msg, err
}

func (o *ODoH) Close() error {
o.conn.CloseIdleConnections()
return nil
}
5 changes: 5 additions & 0 deletions transport/plain.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,8 @@ func (p *Plain) Exchange(m *dns.Msg) (*dns.Msg, error) {

return reply, err
}

// Close is a no-op for the plain transport
func (p *Plain) Close() error {
return nil
}
51 changes: 33 additions & 18 deletions transport/quic.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,13 @@ type QUIC struct {
TLSConfig *tls.Config
NoPMTUD bool
AddLengthPrefix bool
ReuseConn bool

conn *quic.Connection
}

func (q *QUIC) connection() quic.Connection {
return *q.conn
}

// setServerName sets the TLS config server name to the QUIC server
Expand All @@ -43,22 +50,25 @@ func (q *QUIC) setServerName() {
}

func (q *QUIC) Exchange(msg *dns.Msg) (*dns.Msg, error) {
q.setServerName()
if len(q.TLSConfig.NextProtos) == 0 {
log.Warn("No ALPN tokens specified, using default: \"doq\"")
q.TLSConfig.NextProtos = []string{"doq"}
}
log.Debugf("Dialing with QUIC ALPN tokens: %v", q.TLSConfig.NextProtos)
session, err := quic.DialAddr(
context.Background(),
q.Server,
q.TLSConfig,
&quic.Config{
DisablePathMTUDiscovery: q.NoPMTUD,
},
)
if err != nil {
return nil, fmt.Errorf("opening quic session to %s: %v", q.Server, err)
if q.conn == nil || !q.ReuseConn {
q.setServerName()
if len(q.TLSConfig.NextProtos) == 0 {
log.Warn("No ALPN tokens specified, using default: \"doq\"")
q.TLSConfig.NextProtos = []string{"doq"}
}
log.Debugf("Dialing with QUIC ALPN tokens: %v", q.TLSConfig.NextProtos)
conn, err := quic.DialAddr(
context.Background(),
q.Server,
q.TLSConfig,
&quic.Config{
DisablePathMTUDiscovery: q.NoPMTUD,
},
)
if err != nil {
return nil, fmt.Errorf("opening quic session to %s: %v", q.Server, err)
}
q.conn = &conn
}

// Clients and servers MUST NOT send the edns-tcp-keepalive EDNS(0) Option [RFC7828] in any messages sent
Expand All @@ -67,13 +77,14 @@ func (q *QUIC) Exchange(msg *dns.Msg) (*dns.Msg, error) {
if opt := msg.IsEdns0(); opt != nil {
for _, option := range opt.Option {
if option.Option() == dns.EDNS0TCPKEEPALIVE {
_ = session.CloseWithError(DoQProtocolError, "") // Already closing the connection, so we don't care about the error
_ = q.connection().CloseWithError(DoQProtocolError, "") // Already closing the connection, so we don't care about the error
q.conn = nil
return nil, fmt.Errorf("EDNS0 TCP keepalive option is set")
}
}
}

stream, err := session.OpenStream()
stream, err := q.connection().OpenStream()
if err != nil {
return nil, fmt.Errorf("open new stream to %s: %v", q.Server, err)
}
Expand Down Expand Up @@ -137,3 +148,7 @@ func addPrefix(b []byte) (m []byte) {

return m
}

func (q *QUIC) Close() error {
return q.connection().CloseWithError(DoQNoError, "")
}
44 changes: 28 additions & 16 deletions transport/tls.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,30 +14,42 @@ type TLS struct {
Server string
TLSConfig *tls.Config
Timeout time.Duration
ReuseConn bool

conn *tls.Conn
}

func (t *TLS) Exchange(msg *dns.Msg) (*dns.Msg, error) {
conn, err := tls.DialWithDialer(
&net.Dialer{
Timeout: t.Timeout,
},
"tcp",
t.Server,
t.TLSConfig,
)
if err != nil {
return nil, err
}
defer conn.Close()

if err = conn.Handshake(); err != nil {
return nil, err
if t.conn == nil || !t.ReuseConn {
var err error
t.conn, err = tls.DialWithDialer(
&net.Dialer{
Timeout: t.Timeout,
},
"tcp",
t.Server,
t.TLSConfig,
)
if err != nil {
return nil, err
}
if err = t.conn.Handshake(); err != nil {
return nil, err
}
}

c := dns.Conn{Conn: conn}
c := dns.Conn{Conn: t.conn}
if err := c.WriteMsg(msg); err != nil {
return nil, fmt.Errorf("write msg to %s: %v", t.Server, err)
}

return c.ReadMsg()
}

// Close closes the TLS connection
func (t *TLS) Close() error {
if t.conn != nil {
return t.conn.Close()
}
return nil
}
5 changes: 3 additions & 2 deletions transport/tls_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ import (

func tlsTransport() *TLS {
return &TLS{
Server: "dns.quad9.net:853",
Timeout: 1 * time.Second,
Server: "dns.quad9.net:853",
Timeout: 1 * time.Second,
ReuseConn: false,
}
}
1 change: 1 addition & 0 deletions transport/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import "github.com/miekg/dns"

type Transport interface {
Exchange(*dns.Msg) (*dns.Msg, error)
Close() error
}

// Interface guards
Expand Down
Loading

0 comments on commit 0f1ce1e

Please sign in to comment.