Skip to content

Commit

Permalink
Pull request 256: fix-poisoning
Browse files Browse the repository at this point in the history
Merge in GO/dnsproxy from fix-poisoning to master

Squashed commit of the following:

commit 2294054
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Tue May 30 17:01:32 2023 +0300

    upstream: imp code

commit 100b023
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Mon May 29 19:18:35 2023 +0300

    upstream: imp code

commit 2161ced
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Mon May 29 19:02:56 2023 +0300

    all: validate responses, fallback to tcp
  • Loading branch information
EugeneOne1 committed May 31, 2023
1 parent 936bd45 commit 8956a92
Show file tree
Hide file tree
Showing 8 changed files with 196 additions and 62 deletions.
2 changes: 1 addition & 1 deletion internal/bootstrap/resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ type lookupResult struct {

// lookupAsync tries to lookup for ip of host with r and sends the result into
// resCh. It's inteneded to be used as a goroutine.
func lookupAsync(ctx context.Context, r Resolver, host string, resCh chan *lookupResult) {
func lookupAsync(ctx context.Context, r Resolver, host string, resCh chan<- *lookupResult) {
defer log.OnPanic("parallel lookup")

addrs, err := lookup(ctx, r, host)
Expand Down
4 changes: 2 additions & 2 deletions proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -537,8 +537,8 @@ func (p *Proxy) Resolve(dctx *DNSContext) (err error) {
ok, err = p.replyFromUpstream(dctx)

// Don't cache the responses having CD flag, just like Dnsmasq does. It
// prevents the cache from being poisoned with unvalidated answers which
// may differ from validated ones.
// prevents the cache from being poisoned with unvalidated answers which may
// differ from validated ones.
//
// See https://github.com/imp/dnsmasq/blob/770bce967cfc9967273d0acfb3ea018fb7b17522/src/forward.c#L1169-L1172.
if cacheWorks && ok && !dctx.Res.CheckingDisabled {
Expand Down
6 changes: 3 additions & 3 deletions upstream/upstream.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@ import (

// Upstream is an interface for a DNS resolver.
type Upstream interface {
// Exchange sends the DNS query m to this upstream and returns the response
// that has been received or an error if something went wrong.
Exchange(m *dns.Msg) (*dns.Msg, error)
// Exchange sends the DNS query req to this upstream and returns the
// response that has been received or an error if something went wrong.
Exchange(req *dns.Msg) (*dns.Msg, error)

// Address returns the address of the upstream DNS resolver.
Address() string
Expand Down
2 changes: 1 addition & 1 deletion upstream/upstream_doh.go
Original file line number Diff line number Diff line change
Expand Up @@ -650,7 +650,7 @@ func (p *dnsOverHTTPS) probeQUIC(addr string, tlsConfig *tls.Config, ch chan err
func (p *dnsOverHTTPS) probeTLS(dialContext bootstrap.DialHandler, tlsConfig *tls.Config, ch chan error) {
startTime := time.Now()

conn, err := tlsDial(dialContext, "tcp", tlsConfig)
conn, err := tlsDial(dialContext, tlsConfig)
if err != nil {
ch <- fmt.Errorf("opening TLS connection: %w", err)
return
Expand Down
16 changes: 6 additions & 10 deletions upstream/upstream_dot.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ func (p *dnsOverTLS) Exchange(m *dns.Msg) (reply *dns.Msg, err error) {
log.Debug("dot %s: bad conn from pool: %s", p.addr, err)

// Retry.
conn, err = tlsDial(h, "tcp", p.tlsConf.Clone())
conn, err = tlsDial(h, p.tlsConf.Clone())
if err != nil {
return nil, fmt.Errorf(
"dialing %s: connecting to %s: %w",
Expand Down Expand Up @@ -158,7 +158,7 @@ func (p *dnsOverTLS) conn(h bootstrap.DialHandler) (conn net.Conn, err error) {
// Dial a new connection outside the lock, if needed.
defer func() {
if conn == nil {
conn, err = tlsDial(h, "tcp", p.tlsConf.Clone())
conn, err = tlsDial(h, p.tlsConf.Clone())
err = errors.Annotate(err, "connecting to %s: %w", p.tlsConf.ServerName)
}
}()
Expand Down Expand Up @@ -220,20 +220,16 @@ func (p *dnsOverTLS) exchangeWithConn(conn net.Conn, m *dns.Msg) (reply *dns.Msg

// tlsDial is basically the same as tls.DialWithDialer, but we will call our own
// dialContext function to get connection.
func tlsDial(
dialContext bootstrap.DialHandler,
network string,
conf *tls.Config,
) (c *tls.Conn, err error) {
func tlsDial(dialContext bootstrap.DialHandler, conf *tls.Config) (c *tls.Conn, err error) {
// We're using bootstrapped address instead of what's passed to the
// function.
rawConn, err := dialContext(context.Background(), network, "")
rawConn, err := dialContext(context.Background(), string(networkTCP), "")
if err != nil {
return nil, err
}

// We want the timeout to cover the whole process: TCP connection and
// TLS handshake dialTimeout will be used as connection deadLine.
// We want the timeout to cover the whole process: TCP connection and TLS
// handshake dialTimeout will be used as connection deadLine.
conn := tls.Client(rawConn, conf)
err = conn.SetDeadline(time.Now().Add(dialTimeout))
if err != nil {
Expand Down
128 changes: 104 additions & 24 deletions upstream/upstream_plain.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,30 @@ package upstream
import (
"context"
"fmt"
"io"
"net"
"net/url"
"strings"
"time"

"github.com/AdguardTeam/dnsproxy/internal/bootstrap"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/miekg/dns"
)

// network is the type of the network. It's either [networkUDP] or
// [networkTCP].
type network string

const (
// networkUDP is the UDP network.
networkUDP network = "udp"

// networkTCP is the TCP network.
networkTCP network = "tcp"
)

// plainDNS implements the [Upstream] interface for the regular DNS protocol.
type plainDNS struct {
// addr is the DNS server URL. Scheme is always "udp" or "tcp".
Expand All @@ -20,15 +36,26 @@ type plainDNS struct {
// one.
getDialer DialerInitializer

// net is the network of the connections.
net network

// timeout is the timeout for DNS requests.
timeout time.Duration
}

// type check
var _ Upstream = &plainDNS{}

// newPlain returns the plain DNS Upstream.
// newPlain returns the plain DNS Upstream. addr.Scheme should be either "udp"
// or "tcp".
func newPlain(addr *url.URL, opts *Options) (u *plainDNS, err error) {
switch addr.Scheme {
case string(networkUDP), string(networkTCP):
// Go on.
default:
return nil, fmt.Errorf("unsupported url scheme: %s", addr.Scheme)
}

addPort(addr, defaultPortPlain)

getDialer, err := newDialerInitializer(addr, opts)
Expand All @@ -39,50 +66,74 @@ func newPlain(addr *url.URL, opts *Options) (u *plainDNS, err error) {
return &plainDNS{
addr: addr,
getDialer: getDialer,
net: network(addr.Scheme),
timeout: opts.Timeout,
}, nil
}

// Address implements the [Upstream] interface for *plainDNS.
func (p *plainDNS) Address() string {
if p.addr.Scheme == "udp" {
switch p.net {
case networkUDP:
return p.addr.Host
case networkTCP:
return p.addr.String()
default:
panic(fmt.Sprintf("unexpected network: %s", p.net))
}

return p.addr.String()
}

// dialExchange performs a DNS exchange with the specified dial handler.
// network must be either "udp" or "tcp".
// network must be either [networkUDP] or [networkTCP].
func (p *plainDNS) dialExchange(
network string,
network network,
dial bootstrap.DialHandler,
m *dns.Msg,
req *dns.Msg,
) (resp *dns.Msg, err error) {
addr := p.Address()
client := &dns.Client{Timeout: p.timeout}

conn := &dns.Conn{}
if network == "udp" {
if network == networkUDP {
conn.UDPSize = dns.MinMsgSize
}

logBegin(addr, m)
conn.Conn, err = dial(context.Background(), network, "")
if err != nil {
logFinish(addr, err)
logBegin(addr, req)
defer func() { logFinish(addr, err) }()

ctx := context.Background()
conn.Conn, err = dial(ctx, string(network), "")
if err != nil {
return nil, fmt.Errorf("dialing %s over %s: %w", p.addr.Host, network, err)
}

resp, _, err = client.ExchangeWithConn(m, conn)
logFinish(addr, err)
resp, _, err = client.ExchangeWithConn(req, conn)
if isExpectedConnErr(err) {
conn.Conn, err = dial(ctx, string(network), "")
if err != nil {
return nil, fmt.Errorf("dialing %s over %s again: %w", p.addr.Host, network, err)
}

resp, _, err = client.ExchangeWithConn(req, conn)
}

if err != nil {
return resp, fmt.Errorf("exchanging with %s over %s: %w", addr, network, err)
}

return resp, err
return resp, validatePlainResponse(req, resp)
}

// isExpectedConnErr returns true if the error is expected. In this case,
// we will make a second attempt to process the request.
func isExpectedConnErr(err error) (is bool) {
var netErr net.Error

return err != nil && (errors.As(err, &netErr) || errors.Is(err, io.EOF))
}

// Exchange implements the [Upstream] interface for *plainDNS.
func (p *plainDNS) Exchange(m *dns.Msg) (resp *dns.Msg, err error) {
func (p *plainDNS) Exchange(req *dns.Msg) (resp *dns.Msg, err error) {
dial, err := p.getDialer()
if err != nil {
// Don't wrap the error since it's informative enough as is.
Expand All @@ -91,21 +142,50 @@ func (p *plainDNS) Exchange(m *dns.Msg) (resp *dns.Msg, err error) {

addr := p.Address()

resp, err = p.dialExchange(p.addr.Scheme, dial, m)
if p.addr.Scheme == "udp" {
if resp == nil || !resp.Truncated {
return resp, err
}
resp, err = p.dialExchange(p.net, dial, req)
if p.net != networkUDP {
return resp, err
}

log.Debug("plain %s: received truncated, falling back to tcp with %s", addr, &m.Question[0])
if resp == nil {
return resp, err
}

resp, err = p.dialExchange("tcp", dial, m)
if errors.Is(err, errQuestion) {
log.Debug("plain %s: %s, using tcp", addr, err)
} else if resp.Truncated {
log.Debug("plain %s: resp for %s is truncated, using tcp", &req.Question[0], addr)
}

return resp, err
return p.dialExchange(networkTCP, dial, req)
}

// Close implements the [Upstream] interface for *plainDNS.
func (p *plainDNS) Close() (err error) {
return nil
}

// errQuestion is returned when a message has malformed question section.
const errQuestion errors.Error = "bad question section"

// validatePlainResponse validates resp from an upstream DNS server for
// compliance with req. Any error returned wraps [ErrQuestion], since it
// essentially validates the question section of resp.
func validatePlainResponse(req, resp *dns.Msg) (err error) {
if qlen := len(resp.Question); qlen != 1 {
return fmt.Errorf("%w: only 1 question allowed; got %d", errQuestion, qlen)
}

reqQ, respQ := req.Question[0], resp.Question[0]

if reqQ.Qtype != respQ.Qtype {
return fmt.Errorf("%w: mismatched type %s", errQuestion, dns.Type(respQ.Qtype))
}

// Compare the names case-insensitively, just like CoreDNS does.
if !strings.EqualFold(reqQ.Name, respQ.Name) {
return fmt.Errorf("%w: mismatched name %q", errQuestion, respQ.Name)
}

return nil
}
Loading

0 comments on commit 8956a92

Please sign in to comment.