diff --git a/transport_tls.go b/transport_tls.go index dcd014f..93eeead 100644 --- a/transport_tls.go +++ b/transport_tls.go @@ -93,21 +93,28 @@ func (t *TLSTransport) Exchange(ctx context.Context, message *dns.Msg) (*dns.Msg t.access.Lock() conn := t.connections.PopFront() t.access.Unlock() - if conn == nil { - tcpConn, err := t.dialer.DialContext(ctx, N.NetworkTCP, t.serverAddr) - if err != nil { - return nil, err + if conn != nil { + response, err := t.exchange(message, conn) + if err == nil { + return response, nil } - tlsConn := tls.Client(tcpConn, &tls.Config{ - ServerName: t.serverAddr.AddrString(), - }) - err = tlsConn.HandshakeContext(ctx) - if err != nil { - tcpConn.Close() - return nil, err - } - conn = &tlsDNSConn{Conn: tlsConn} } + tcpConn, err := t.dialer.DialContext(ctx, N.NetworkTCP, t.serverAddr) + if err != nil { + return nil, err + } + tlsConn := tls.Client(tcpConn, &tls.Config{ + ServerName: t.serverAddr.AddrString(), + }) + err = tlsConn.HandshakeContext(ctx) + if err != nil { + tcpConn.Close() + return nil, err + } + return t.exchange(message, &tlsDNSConn{Conn: tlsConn}) +} + +func (t *TLSTransport) exchange(message *dns.Msg, conn *tlsDNSConn) (*dns.Msg, error) { messageId := message.Id conn.queryId++ message.Id = conn.queryId