diff --git a/intra/dns53/dot.go b/intra/dns53/dot.go index 8ef26fc8..ff3369e8 100644 --- a/intra/dns53/dot.go +++ b/intra/dns53/dot.go @@ -25,6 +25,8 @@ import ( "golang.org/x/net/context" ) +const usepool = true + type dot struct { ctx context.Context done context.CancelFunc @@ -37,6 +39,7 @@ type dot struct { c *dns.Client c3 *dns.Client // with ech rd *protect.RDial + pool *core.MultConnPool[uintptr] proxies ipn.Proxies // may be nil relay ipn.Proxy // may be nil est core.P2QuantileEstimator @@ -93,6 +96,7 @@ func NewTLSTransport(ctx context.Context, id, rawurl string, addrs []string, px proxies: px, rd: rd, relay: relay, + pool: core.NewMultConnPool[uintptr](ctx), est: core.NewP50Estimator(ctx), } ech := t.ech() @@ -153,13 +157,19 @@ func (t *dot) doQuery(pid string, q *dns.Msg) (response *dns.Msg, elapsed time.D return } -func (t *dot) tlsdial(rd protect.RDialer) (_ *dns.Conn, err error) { +func (t *dot) tlsdial(rd protect.RDialer) (_ *dns.Conn, who uintptr, err error) { + who = rd.Handle() + if c := t.fromPool(who); c != nil { + return c, who, nil + } + + var usingech bool var c net.Conn = nil // dot is always tcp addr := t.addr // t.addr may be ip or hostname if t.c3 != nil { // may be nil if ech is not available cfg := t.c3.TLSConfig // don't clone; may be modified by dialers.DialWithTls c, err = dialers.DialWithTls(rd, cfg, "tcp", addr) - log.W("dot: tlsdial: (%s) ech; err? %v", t.id, err) + usingech = true } if c == nil && core.IsNil(c) { // no ech or ech failed cfg := t.c.TLSConfig @@ -169,34 +179,64 @@ func (t *dot) tlsdial(rd protect.RDialer) (_ *dns.Conn, err error) { _ = c.SetDeadline(time.Now().Add(dottimeout)) // todo: higher timeout for if using proxy dialer // _ = c.SetDeadline(time.Now().Add(dottimeout * 2)) - return &dns.Conn{Conn: c, UDPSize: t.c.UDPSize}, err + return &dns.Conn{Conn: c, UDPSize: t.c.UDPSize}, who, err } else { if err == nil { - log.W("dot: tlsdial: (%s) nil conn/err for %s", t.id, addr) err = errNoNet } + log.W("dot: tlsdial: (%s) nil conn/err for %s, ech? %t; err? %v", + t.id, addr, usingech, err) } - return nil, err + return nil, who, err } -func (t *dot) pxdial(pid string) (*dns.Conn, error) { +func (t *dot) pxdial(pid string) (*dns.Conn, uintptr, error) { var px ipn.Proxy if t.relay != nil { // relay takes precedence px = t.relay } else if t.proxies != nil { // use proxy, if specified var err error if px, err = t.proxies.ProxyFor(pid); err != nil { - return nil, err + return nil, core.Nobody, err } } if px == nil { - return nil, dnsx.ErrNoProxyProvider + return nil, core.Nobody, dnsx.ErrNoProxyProvider } + pid = px.ID() log.V("dot: pxdial: (%s) using relay/proxy %s at %s", - t.id, px.ID(), px.GetAddr()) + t.id, pid, px.GetAddr()) + return t.tlsdial(px.Dialer()) } +func (t *dot) toPool(id uintptr, c *dns.Conn) { + if !usepool || id == core.Nobody { + clos(c) + return + } + ok := t.pool.Put(id, c) + log.V("dot: pool: (%s) put for %v; ok? %t", t.id, id, ok) +} + +func (t *dot) fromPool(id uintptr) (c *dns.Conn) { + if !usepool || id == core.Nobody { + return + } + + pooled := t.pool.Get(id) + if pooled == nil || core.IsNil(pooled) { + return + } + var ok bool + if c, ok = pooled.(*dns.Conn); !ok { // unlikely + clos(pooled) + return + } + log.V("dot: pool: (%s) got conn from %v; %d", t.id, id) + return +} + func clos(c net.Conn) { core.CloseConn(c) } @@ -210,18 +250,17 @@ func (t *dot) sendRequest(pid string, q *dns.Msg) (ans *dns.Msg, elapsed time.Du } var conn *dns.Conn + var who uintptr userelay := t.relay != nil useproxy := len(pid) != 0 // pid == dnsx.NetNoProxy => ipn.Base if useproxy || userelay { - conn, err = t.pxdial(pid) + conn, who, err = t.pxdial(pid) } else { // ref dns.Client.Dial - conn, err = t.tlsdial(t.rd) + conn, who, err = t.tlsdial(t.rd) } if err == nil { - // FIXME: conn pooling using t.c.Dial + ExchangeWithConn ans, elapsed, err = t.c.ExchangeWithConn(q, conn) - clos(conn) } // fallthrough raddr := remoteAddrIfAny(conn) @@ -232,6 +271,7 @@ func (t *dot) sendRequest(pid string, q *dns.Msg) (ans *dns.Msg, elapsed time.Du t.id, xdns.Size(q), xdns.EDNS0PadLen(q), err, ok, t.host, raddr) qerr = dnsx.NewSendFailedQueryError(err) } else { + t.toPool(who, conn) // or close dialers.Confirm2(t.host, raddr) } return @@ -321,3 +361,10 @@ func url2addr(url string) string { } return url } + +func logev(err error) log.LogFn { + if err != nil { + return log.E + } + return log.V +}