From 8fe1437c54bb32ae96bb6830962d0a679fd18753 Mon Sep 17 00:00:00 2001 From: phillip-stephens Date: Mon, 9 Sep 2024 16:19:43 -0400 Subject: [PATCH 01/24] working input demultiplexor with tls --- src/cli/worker_manager.go | 150 ++++++++++++++++++++++++++++---------- src/zdns/types.go | 27 +++++++ 2 files changed, 137 insertions(+), 40 deletions(-) diff --git a/src/cli/worker_manager.go b/src/cli/worker_manager.go index d93dedb2..804beab1 100644 --- a/src/cli/worker_manager.go +++ b/src/cli/worker_manager.go @@ -432,6 +432,44 @@ func populateLocalAddresses(gc *CLIConf, config *zdns.ResolverConfig) (*zdns.Res return config, nil } +type WorkerPools struct { + WorkerPools []chan *InputLineWithNameServer +} + +func NewWorkerPools(numPools int) *WorkerPools { + workerPools := make([]chan *InputLineWithNameServer, numPools) + for i := 0; i < numPools; i++ { + workerPools[i] = make(chan *InputLineWithNameServer) + } + return &WorkerPools{WorkerPools: workerPools} +} + +type InputLineWithNameServer struct { + Line string + NameServer *zdns.NameServer +} + +// inputDeMultiplxer is a single goroutine that reads from the input channel and sends the input to the appropriate worker pool channel +// The goal is that a query for a single name server will consistently go to the same worker pool which 1+ threads will read from +// This is especially useful for HTTPS/TLS/TCP based lookups where repeating the initial handshakes would be wasteful +func inputDeMultiplexer(nameServers []zdns.NameServer, inChan <-chan string, workerPools *WorkerPools, wg *sync.WaitGroup) error { + wg.Add(1) + defer wg.Done() + // defer closing the worker pool chans + defer func() { + for _, pool := range workerPools.WorkerPools { + close(pool) + } + }() + for line := range inChan { + nsIndex := rand.Intn(len(nameServers)) + randomNS := nameServers[nsIndex] + chanID := nsIndex % len(workerPools.WorkerPools) + workerPools.WorkerPools[chanID] <- &InputLineWithNameServer{Line: line, NameServer: &randomNS} + } 
+ return nil +} + func Run(gc CLIConf) { gc = *populateCLIConfig(&gc) resolverConfig := populateResolverConfig(&gc) @@ -468,6 +506,27 @@ func Run(gc CLIConf) { log.Fatal("Output handler is nil") } + nameServers := util.Concat(resolverConfig.ExternalNameServersV4, resolverConfig.ExternalNameServersV6, resolverConfig.RootNameServersV4, resolverConfig.RootNameServersV6) + // de-dupe + nsLookupMap := make(map[uint32]struct{}) + uniqNameServers := make([]zdns.NameServer, 0, len(nameServers)) + for _, ns := range nameServers { + hash, err := ns.Hash() + if err != nil { + log.Fatalf("could not hash name server %s: %v", ns.String(), err) + } + if _, ok := nsLookupMap[hash]; !ok { + nsLookupMap[hash] = struct{}{} + uniqNameServers = append(uniqNameServers, ns) + } + } + numberOfWorkerPools := len(uniqNameServers) + if gc.Threads < numberOfWorkerPools { + // multiple threads can share a channel, but we can't have more channels than threads + numberOfWorkerPools = gc.Threads + } + workerPools := NewWorkerPools(numberOfWorkerPools) + // Use handlers to populate the input and output/results channel go func() { inErr := inHandler.FeedChannel(inChan, &routineWG) @@ -475,6 +534,12 @@ func Run(gc CLIConf) { log.Fatal(fmt.Sprintf("could not feed input channel: %v", inErr)) } }() + go func() { + plexErr := inputDeMultiplexer(uniqNameServers, inChan, workerPools, &routineWG) + if plexErr != nil { + log.Fatal(fmt.Sprintf("could not de-multiplex input channel: %v", plexErr)) + } + }() go func() { outErr := outHandler.WriteResults(outChan, &routineWG) if outErr != nil { @@ -490,8 +555,9 @@ func Run(gc CLIConf) { // create shared cache for all threads to share for i := 0; i < gc.Threads; i++ { i := i + channelID := i % len(workerPools.WorkerPools) go func(threadID int) { - initWorkerErr := doLookupWorker(&gc, resolverConfig, inChan, outChan, metaChan, &lookupWG) + initWorkerErr := doLookupWorker(&gc, resolverConfig, workerPools.WorkerPools[channelID], outChan, metaChan, &lookupWG) if 
initWorkerErr != nil { log.Fatalf("could not start lookup worker #%d: %v", i, initWorkerErr) } @@ -542,7 +608,7 @@ func Run(gc CLIConf) { } // doLookupWorker is a single worker thread that processes lookups from the input channel. It calls wg.Done when it is finished. -func doLookupWorker(gc *CLIConf, rc *zdns.ResolverConfig, input <-chan string, output chan<- string, metaChan chan<- routineMetadata, wg *sync.WaitGroup) error { +func doLookupWorker(gc *CLIConf, rc *zdns.ResolverConfig, input <-chan *InputLineWithNameServer, output chan<- string, metaChan chan<- routineMetadata, wg *sync.WaitGroup) error { defer wg.Done() resolver, err := zdns.InitResolver(rc) if err != nil { @@ -552,44 +618,7 @@ func doLookupWorker(gc *CLIConf, rc *zdns.ResolverConfig, input <-chan string, o metadata.Status = make(map[zdns.Status]int) for line := range input { // we'll process each module sequentially, parallelism is per-domain - res := zdns.Result{Results: make(map[string]zdns.SingleModuleResult, len(gc.ActiveModules))} - // get the fields that won't change for each lookup module - rawName := "" - var nameServer *zdns.NameServer - var nameServers []zdns.NameServer - nameServerString := "" - var rank int - var entryMetadata string - if gc.AlexaFormat { - rawName, rank = parseAlexa(line) - res.AlexaRank = rank - } else if gc.MetadataFormat { - rawName, entryMetadata = parseMetadataInputLine(line) - res.Metadata = entryMetadata - } else if gc.NameServerMode { - nameServers, err = convertNameServerStringToNameServer(line, rc.IPVersionMode, rc.DNSOverTLS, rc.DNSOverHTTPS) - if err != nil { - log.Fatal("unable to parse name server: ", line) - } - if len(nameServers) == 0 { - log.Fatal("no name servers found in line: ", line) - } - // if user provides a domain name for the name server (one.one.one.one) we'll pick one of the IPs at random - nameServer = &nameServers[rand.Intn(len(nameServers))] - } else { - rawName, nameServerString = parseNormalInputLine(line) - if len(nameServerString) 
!= 0 { - nameServers, err = convertNameServerStringToNameServer(nameServerString, rc.IPVersionMode, rc.DNSOverTLS, rc.DNSOverHTTPS) - if err != nil { - log.Fatal("unable to parse name server: ", line) - } - if len(nameServers) == 0 { - log.Fatal("no name servers found in line: ", line) - } - // if user provides a domain name for the name server (one.one.one.one) we'll pick one of the IPs at random - nameServer = &nameServers[rand.Intn(len(nameServers))] - } - } + res, rawName, nameServer := parseInputLine(gc, rc, line) res.Name = rawName // handle per-module lookups for moduleName, module := range gc.ActiveModules { @@ -650,6 +679,47 @@ func doLookupWorker(gc *CLIConf, rc *zdns.ResolverConfig, input <-chan string, o return nil } +func parseInputLine(gc *CLIConf, rc *zdns.ResolverConfig, line *InputLineWithNameServer) (*zdns.Result, string, *zdns.NameServer) { + res := zdns.Result{Results: make(map[string]zdns.SingleModuleResult, len(gc.ActiveModules))} + // get the fields that won't change for each lookup module + rawName := "" + nameServer := line.NameServer + nameServerString := "" + var rank int + var entryMetadata string + if gc.AlexaFormat { + rawName, rank = parseAlexa(line.Line) + res.AlexaRank = rank + } else if gc.MetadataFormat { + rawName, entryMetadata = parseMetadataInputLine(line.Line) + res.Metadata = entryMetadata + } else if gc.NameServerMode { + nameServers, err := convertNameServerStringToNameServer(line.Line, rc.IPVersionMode, rc.DNSOverTLS, rc.DNSOverHTTPS) + if err != nil { + log.Fatal("unable to parse name server: ", line.Line) + } + if len(nameServers) == 0 { + log.Fatal("no name servers found in line: ", line.Line) + } + // if user provides a domain name for the name server (one.one.one.one) we'll pick one of the IPs at random + nameServer = &nameServers[rand.Intn(len(nameServers))] + } else { + rawName, nameServerString = parseNormalInputLine(line.Line) + if len(nameServerString) != 0 { + nameServers, err := 
convertNameServerStringToNameServer(nameServerString, rc.IPVersionMode, rc.DNSOverTLS, rc.DNSOverHTTPS) + if err != nil { + log.Fatal("unable to parse name server: ", line.Line) + } + if len(nameServers) == 0 { + log.Fatal("no name servers found in line: ", line.Line) + } + // if user provides a domain name for the name server (one.one.one.one) we'll pick one of the IPs at random + nameServer = &nameServers[rand.Intn(len(nameServers))] + } + } + return &res, rawName, nameServer +} + func parseAlexa(line string) (string, int) { s := strings.SplitN(line, ",", 2) rank, err := strconv.Atoi(s[0]) diff --git a/src/zdns/types.go b/src/zdns/types.go index 309df459..9d0ce8e7 100644 --- a/src/zdns/types.go +++ b/src/zdns/types.go @@ -15,6 +15,8 @@ package zdns import ( "fmt" + "github.com/pkg/errors" + "hash/fnv" "net" "github.com/zmap/zdns/src/internal/util" @@ -122,6 +124,31 @@ func (ns *NameServer) String() string { return "" } +func (ns *NameServer) Hash() (uint32, error) { + h := fnv.New32a() + + // Hash the IP address + _, err := h.Write(ns.IP) + if err != nil { + return 0, errors.Wrap(err, "unable to hash IP address") + } + + // Hash the Port + portBytes := []byte{byte(ns.Port >> 8), byte(ns.Port & 0xff)} + _, err = h.Write(portBytes) + if err != nil { + return 0, errors.Wrap(err, "unable to hash port") + } + + // Hash the DomainName + _, err = h.Write([]byte(ns.DomainName)) + if err != nil { + return 0, errors.Wrap(err, "unable to hash domain name") + } + + return h.Sum32(), nil +} + func (ns *NameServer) PopulateDefaultPort(usingDoT, usingDoH bool) { if ns.Port != 0 { return From 0957e6f65d529758d8f2ebcea959079a6dfacb18 Mon Sep 17 00:00:00 2001 From: phillip-stephens Date: Tue, 10 Sep 2024 11:13:19 -0400 Subject: [PATCH 02/24] handled tcp conns --- src/zdns/lookup.go | 28 ++++++++++++++++++++-------- src/zdns/resolver.go | 38 ++++++++++++++++++++++++++------------ 2 files changed, 46 insertions(+), 20 deletions(-) diff --git a/src/zdns/lookup.go b/src/zdns/lookup.go 
index 0763c3b3..8dd7226d 100644 --- a/src/zdns/lookup.go +++ b/src/zdns/lookup.go @@ -453,7 +453,7 @@ func (r *Resolver) retryingLookup(ctx context.Context, q Question, nameServer *N } else if r.dnsOverTLSEnabled { result, status, err = doDoTLookup(ctx, connInfo, q, nameServer, recursive, r.ednsOptions, r.dnsSecEnabled, r.checkingDisabledBit) } else { - result, status, err = wireLookup(ctx, connInfo.udpClient, connInfo.tcpClient, connInfo.conn, q, nameServer, r.ednsOptions, recursive, r.dnsSecEnabled, r.checkingDisabledBit) + result, status, err = wireLookup(ctx, connInfo.udpClient, connInfo.tcpClient, connInfo.udpConn, connInfo.tcpConn, q, nameServer, r.ednsOptions, recursive, r.dnsSecEnabled, r.checkingDisabledBit) } if status != StatusTimeout || i == r.retries { return result, status, i + 1, err @@ -615,7 +615,7 @@ func doDoHLookup(ctx context.Context, httpClient *http.Client, q Question, nameS // wireLookup performs a DNS lookup on-the-wire with the given parameters // Attempts a UDP lookup first, then falls back to TCP if necessary (if the UDP response encounters an error or is truncated) -func wireLookup(ctx context.Context, udp *dns.Client, tcp *dns.Client, conn *dns.Conn, q Question, nameServer *NameServer, ednsOptions []dns.EDNS0, recursive, dnssec, checkingDisabled bool) (SingleQueryResult, Status, error) { +func wireLookup(ctx context.Context, udp *dns.Client, tcp *dns.Client, udpConn, tcpConn *dns.Conn, q Question, nameServer *NameServer, ednsOptions []dns.EDNS0, recursive, dnssec, checkingDisabled bool) (SingleQueryResult, Status, error) { res := SingleQueryResult{Answers: []interface{}{}, Authorities: []interface{}{}, Additional: []interface{}{}} res.Resolver = nameServer.String() @@ -634,23 +634,35 @@ func wireLookup(ctx context.Context, udp *dns.Client, tcp *dns.Client, conn *dns var err error if udp != nil { res.Protocol = "udp" - if conn != nil { + if udpConn != nil { dst, _ := net.ResolveUDPAddr("udp", nameServer.String()) - r, _, err = 
udp.ExchangeWithConnToContext(ctx, m, conn, dst) + r, _, err = udp.ExchangeWithConnToContext(ctx, m, udpConn, dst) } else { r, _, err = udp.ExchangeContext(ctx, m, nameServer.String()) } // if record comes back truncated, but we have a TCP connection, try again with that if r != nil && (r.Truncated || r.Rcode == dns.RcodeBadTrunc) { if tcp != nil { - return wireLookup(ctx, nil, tcp, conn, q, nameServer, ednsOptions, recursive, dnssec, checkingDisabled) + return wireLookup(ctx, nil, tcp, udpConn, tcpConn, q, nameServer, ednsOptions, recursive, dnssec, checkingDisabled) } else { return res, StatusTruncated, err } } - } else { - res.Protocol = "tcp" - r, _, err = tcp.ExchangeContext(ctx, m, nameServer.String()) + } else if tcp != nil { + // TCP + if tcpConn != nil && tcpConn.RemoteAddr != nil && tcpConn.RemoteAddr.String() == nameServer.String() { + // we have a connection to this nameserver, use it + res.Protocol = "tcp" + var addr *net.TCPAddr + addr, err = net.ResolveTCPAddr("tcp", nameServer.String()) + if err != nil { + return SingleQueryResult{}, StatusError, fmt.Errorf("could not resolve TCP address %s: %v", nameServer.String(), err) + } + r, _, err = tcp.ExchangeWithConnToContext(ctx, m, tcpConn, addr) + } else { + res.Protocol = "tcp" + r, _, err = tcp.ExchangeContext(ctx, m, nameServer.String()) + } } if err != nil || r == nil { if nerr, ok := err.(net.Error); ok { diff --git a/src/zdns/resolver.go b/src/zdns/resolver.go index 6d485bba..989cb8b4 100644 --- a/src/zdns/resolver.go +++ b/src/zdns/resolver.go @@ -233,7 +233,8 @@ func NewResolverConfig() *ResolverConfig { type ConnectionInfo struct { udpClient *dns.Client tcpClient *dns.Client - conn *dns.Conn // for socket re-use + udpConn *dns.Conn // for socket re-use with UDP + tcpConn *dns.Conn // for socket re-use with TCP, if RemoteAddr doesn't change, we don't re-handshake httpsClient *http.Client // for DoH tlsConn *dns.Conn // for DoT tlsHandshake *tls.ServerHandshake // for DoT, used to print TLS 
handshake to user @@ -370,7 +371,7 @@ func (r *Resolver) getConnectionInfo(nameServer *NameServer) (*ConnectionInfo, e // what local addresses should we use? isNSIPv6 := util.IsIPv6(&nameServer.IP) isLoopback := nameServer.IP.IsLoopback() - // check if we have a pre-existing conn info + // check if we have a pre-existing udpConn info if isNSIPv6 && isLoopback && r.connInfoIPv6Loopback != nil { return r.connInfoIPv6Loopback, nil } else if isNSIPv6 && !isLoopback && r.connInfoIPv6Internet != nil { @@ -439,8 +440,8 @@ func (r *Resolver) getConnectionInfo(nameServer *NameServer) (*ConnectionInfo, e if err != nil { return nil, fmt.Errorf("unable to create UDP connection: %w", err) } - connInfo.conn = new(dns.Conn) - connInfo.conn.Conn = conn + connInfo.udpConn = new(dns.Conn) + connInfo.udpConn.Conn = conn } usingUDP := r.transportMode == UDPOrTCP || r.transportMode == UDPOnly @@ -462,6 +463,19 @@ func (r *Resolver) getConnectionInfo(nameServer *NameServer) (*ConnectionInfo, e LocalAddr: &net.TCPAddr{IP: connInfo.localAddr}, } } + if r.transportMode == TCPOnly && r.shouldRecycleSockets { + // create persistent TCP connection to nameserver + if connInfo.tcpConn == nil || connInfo.tcpConn.RemoteAddr != nil || connInfo.tcpConn.RemoteAddr.String() != nameServer.String() { + // RemoteAddr has changed, we need to re-handshake + conn, err := net.DialTCP("tcp", &net.TCPAddr{IP: connInfo.localAddr}, &net.TCPAddr{IP: nameServer.IP, Port: int(nameServer.Port)}) + if err != nil { + return nil, fmt.Errorf("unable to create TCP connection for nameserver %s: %w", nameServer.String(), err) + } + connInfo.tcpConn = new(dns.Conn) + connInfo.tcpConn.Conn = conn + connInfo.tcpConn.RemoteAddr = &net.TCPAddr{IP: nameServer.IP, Port: int(nameServer.Port)} + } + } if r.dnsOverHTTPSEnabled { // Create a http.Client with the custom transport connInfo.httpsClient = &http.Client{ @@ -561,23 +575,23 @@ func (r *Resolver) IterativeLookup(q *Question) (*SingleQueryResult, Trace, Stat // Close cleans 
up any resources used by the resolver. This should be called when the resolver is no longer needed. // Lookup will panic if called after Close. func (r *Resolver) Close() { - if r.connInfoIPv4Internet != nil && r.connInfoIPv4Internet.conn != nil { - if err := r.connInfoIPv4Internet.conn.Close(); err != nil { + if r.connInfoIPv4Internet != nil && r.connInfoIPv4Internet.udpConn != nil { + if err := r.connInfoIPv4Internet.udpConn.Close(); err != nil { log.Errorf("error closing IPv4 connection: %v", err) } } - if r.connInfoIPv6Internet != nil && r.connInfoIPv6Internet.conn != nil { - if err := r.connInfoIPv6Internet.conn.Close(); err != nil { + if r.connInfoIPv6Internet != nil && r.connInfoIPv6Internet.udpConn != nil { + if err := r.connInfoIPv6Internet.udpConn.Close(); err != nil { log.Errorf("error closing IPv6 connection: %v", err) } } - if r.connInfoIPv4Loopback != nil && r.connInfoIPv4Loopback.conn != nil { - if err := r.connInfoIPv4Loopback.conn.Close(); err != nil { + if r.connInfoIPv4Loopback != nil && r.connInfoIPv4Loopback.udpConn != nil { + if err := r.connInfoIPv4Loopback.udpConn.Close(); err != nil { log.Errorf("error closing IPv4 loopback connection: %v", err) } } - if r.connInfoIPv6Loopback != nil && r.connInfoIPv6Loopback.conn != nil { - if err := r.connInfoIPv6Loopback.conn.Close(); err != nil { + if r.connInfoIPv6Loopback != nil && r.connInfoIPv6Loopback.udpConn != nil { + if err := r.connInfoIPv6Loopback.udpConn.Close(); err != nil { log.Errorf("error closing IPv6 loopback connection: %v", err) } } From d425d4002f8d4499488ff30ba5c9c3fa9559a25b Mon Sep 17 00:00:00 2001 From: phillip-stephens Date: Tue, 10 Sep 2024 11:32:52 -0400 Subject: [PATCH 03/24] handle HTTPS de-multiplexing --- src/cli/worker_manager.go | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/src/cli/worker_manager.go b/src/cli/worker_manager.go index 804beab1..b0ec2066 100644 --- a/src/cli/worker_manager.go +++ b/src/cli/worker_manager.go @@ -451,7 
+451,8 @@ type InputLineWithNameServer struct { // inputDeMultiplxer is a single goroutine that reads from the input channel and sends the input to the appropriate worker pool channel // The goal is that a query for a single name server will consistently go to the same worker pool which 1+ threads will read from -// This is especially useful for HTTPS/TLS/TCP based lookups where repeating the initial handshakes would be wasteful +// This is especially useful for TLS/TCP based lookups where repeating the initial handshakes would be wasteful +// For HTTPS conns, they depend on the domain name rather than the IP address, so we need to ensure that all queries for a domain name go to the same channel with HTTPS func inputDeMultiplexer(nameServers []zdns.NameServer, inChan <-chan string, workerPools *WorkerPools, wg *sync.WaitGroup) error { wg.Add(1) defer wg.Done() @@ -520,6 +521,19 @@ func Run(gc CLIConf) { uniqNameServers = append(uniqNameServers, ns) } } + // DoH lookups only depend on domain name, remove nameservers with duplicate domains + // We don't want the deMultiplexer to send the same domain (with different IPs) to different worker pools + if gc.DNSOverHTTPS { + nsDomainLookupMap := make(map[string]struct{}) + uniqueDomainNSes := make([]zdns.NameServer, 0, len(uniqNameServers)) + for _, ns := range uniqNameServers { + if _, ok := nsDomainLookupMap[ns.DomainName]; !ok { + nsDomainLookupMap[ns.DomainName] = struct{}{} + uniqueDomainNSes = append(uniqueDomainNSes, ns) + } + } + uniqNameServers = uniqueDomainNSes + } numberOfWorkerPools := len(uniqNameServers) if gc.Threads < numberOfWorkerPools { // multiple threads can share a channel, but we can't have more channels than threads From 69c51069168b73a570dea61e194a1e0f415c9008 Mon Sep 17 00:00:00 2001 From: phillip-stephens Date: Tue, 10 Sep 2024 11:34:28 -0400 Subject: [PATCH 04/24] lint --- src/cli/worker_manager.go | 3 ++- src/zdns/types.go | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git 
a/src/cli/worker_manager.go b/src/cli/worker_manager.go index b0ec2066..464bb15d 100644 --- a/src/cli/worker_manager.go +++ b/src/cli/worker_manager.go @@ -511,8 +511,9 @@ func Run(gc CLIConf) { // de-dupe nsLookupMap := make(map[uint32]struct{}) uniqNameServers := make([]zdns.NameServer, 0, len(nameServers)) + var hash uint32 for _, ns := range nameServers { - hash, err := ns.Hash() + hash, err = ns.Hash() if err != nil { log.Fatalf("could not hash name server %s: %v", ns.String(), err) } diff --git a/src/zdns/types.go b/src/zdns/types.go index 9d0ce8e7..701426f7 100644 --- a/src/zdns/types.go +++ b/src/zdns/types.go @@ -15,10 +15,11 @@ package zdns import ( "fmt" - "github.com/pkg/errors" "hash/fnv" "net" + "github.com/pkg/errors" + "github.com/zmap/zdns/src/internal/util" ) From c45ebc8bc7ced5267252acbba2bf86746fb9aa59 Mon Sep 17 00:00:00 2001 From: phillip-stephens Date: Tue, 10 Sep 2024 11:43:59 -0400 Subject: [PATCH 05/24] improved error msg if user only supplies IPv4 addresses and we fail config validation --- src/zdns/resolver.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/zdns/resolver.go b/src/zdns/resolver.go index 989cb8b4..eda1620d 100644 --- a/src/zdns/resolver.go +++ b/src/zdns/resolver.go @@ -113,11 +113,11 @@ func (rc *ResolverConfig) Validate() error { // External Nameservers if rc.IPVersionMode != IPv6Only && len(rc.ExternalNameServersV4) == 0 { // If IPv4 is supported, we require at least one IPv4 external nameserver - return errors.New("must have at least one external IPv4 name server if IPv4 mode is enabled") + return errors.New("must have at least one external IPv4 name server if IPv4 mode is enabled. 
Use IPv6 only if you don't have IPv4 nameservers") } if rc.IPVersionMode != IPv4Only && len(rc.ExternalNameServersV6) == 0 { // If IPv6 is supported, we require at least one IPv6 external nameserver - return errors.New("must have at least one external IPv6 name server if IPv6 mode is enabled") + return errors.New("must have at least one external IPv6 name server if IPv6 mode is enabled. Use IPv4 only if you don't have IPv6 nameservers") } // Validate all nameservers have ports and are valid IPs From 67e81497588252ef782ac72c4aa4efad493b4dfd Mon Sep 17 00:00:00 2001 From: phillip-stephens Date: Tue, 10 Sep 2024 12:02:19 -0400 Subject: [PATCH 06/24] added AXFR edge case handling --- src/cli/worker_manager.go | 13 +++++++++++++ src/modules/spf/spf.go | 4 ++-- 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/src/cli/worker_manager.go b/src/cli/worker_manager.go index 464bb15d..bfb2756d 100644 --- a/src/cli/worker_manager.go +++ b/src/cli/worker_manager.go @@ -637,6 +637,14 @@ func doLookupWorker(gc *CLIConf, rc *zdns.ResolverConfig, input <-chan *InputLin res.Name = rawName // handle per-module lookups for moduleName, module := range gc.ActiveModules { + if moduleName == "AXFR" { + // special case, AXFR has its own nameserver handling. We'll only take nameservers if the user provides it + // not the "suggestion" from the de-multiplexor + if nameServer.String() == line.NameServer.String() { + // this name server is the suggested one from the de-multiplexor, we'll remove it + nameServer = nil + } + } var innerRes interface{} var trace zdns.Trace var status zdns.Status @@ -698,6 +706,11 @@ func parseInputLine(gc *CLIConf, rc *zdns.ResolverConfig, line *InputLineWithNam res := zdns.Result{Results: make(map[string]zdns.SingleModuleResult, len(gc.ActiveModules))} // get the fields that won't change for each lookup module rawName := "" + // this is the name server "suggested" by the de-multiplexor. 
The goal is that if + // 1) the user doesn't provide a nameserver + // 2) we're in external lookup mode + // then we'll use the suggestion. This is to avoid the overhead of re-handshaking for each lookup + // it's overwritten if the user provides a nameserver as part of the input line below nameServer := line.NameServer nameServerString := "" var rank int diff --git a/src/modules/spf/spf.go b/src/modules/spf/spf.go index bb1aa48b..d5a2a9a7 100644 --- a/src/modules/spf/spf.go +++ b/src/modules/spf/spf.go @@ -48,8 +48,8 @@ func (spfMod *SpfLookupModule) CLIInit(gc *cli.CLIConf, rc *zdns.ResolverConfig) return spfMod.BasicLookupModule.CLIInit(gc, rc) } -func (spfMod *SpfLookupModule) Lookup(r *zdns.Resolver, name string, resolver *zdns.NameServer) (interface{}, zdns.Trace, zdns.Status, error) { - innerRes, trace, status, err := spfMod.BasicLookupModule.Lookup(r, name, resolver) +func (spfMod *SpfLookupModule) Lookup(r *zdns.Resolver, name string, nameServer *zdns.NameServer) (interface{}, zdns.Trace, zdns.Status, error) { + innerRes, trace, status, err := spfMod.BasicLookupModule.Lookup(r, name, nameServer) castedInnerRes, ok := innerRes.(*zdns.SingleQueryResult) if !ok { return nil, trace, status, errors.New("lookup didn't return a single query result type") From fade130edf3588db99bdc2518c25b7664b130d7c Mon Sep 17 00:00:00 2001 From: phillip-stephens Date: Tue, 10 Sep 2024 12:12:03 -0400 Subject: [PATCH 07/24] added comments --- src/cli/worker_manager.go | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/cli/worker_manager.go b/src/cli/worker_manager.go index bfb2756d..beb1b9a6 100644 --- a/src/cli/worker_manager.go +++ b/src/cli/worker_manager.go @@ -432,6 +432,8 @@ func populateLocalAddresses(gc *CLIConf, config *zdns.ResolverConfig) (*zdns.Res return config, nil } +// WorkerPools are a collection of channels that workers can read from +// 1+ threads will read from a pooled channel, and the inputDeMultiplexer will send input to the 
appropriate channel type WorkerPools struct { WorkerPools []chan *InputLineWithNameServer } @@ -444,6 +446,9 @@ func NewWorkerPools(numPools int) *WorkerPools { return &WorkerPools{WorkerPools: workerPools} } +// InputLineWithNameServer is a struct that contains a line of input and the name server to use for the lookup +// This name server is a "suggestion", --iterative lookups will ignore it as well as AXFR lookups +// The goal is to attempt to send all queries for a single name server to the same worker pool type InputLineWithNameServer struct { Line string NameServer *zdns.NameServer @@ -451,8 +456,7 @@ type InputLineWithNameServer struct { // inputDeMultiplxer is a single goroutine that reads from the input channel and sends the input to the appropriate worker pool channel // The goal is that a query for a single name server will consistently go to the same worker pool which 1+ threads will read from -// This is especially useful for TLS/TCP based lookups where repeating the initial handshakes would be wasteful -// For HTTPS conns, they depend on the domain name rather than the IP address, so we need to ensure that all queries for a domain name go to the same channel with HTTPS +// This is especially useful for TLS/TCP/HTTPS based lookups where repeating the initial handshakes would be wasteful func inputDeMultiplexer(nameServers []zdns.NameServer, inChan <-chan string, workerPools *WorkerPools, wg *sync.WaitGroup) error { wg.Add(1) defer wg.Done() @@ -570,6 +574,7 @@ func Run(gc CLIConf) { // create shared cache for all threads to share for i := 0; i < gc.Threads; i++ { i := i + // assign each worker to a worker pool, we'll loop around if we have more workers than pools channelID := i % len(workerPools.WorkerPools) go func(threadID int) { initWorkerErr := doLookupWorker(&gc, resolverConfig, workerPools.WorkerPools[channelID], outChan, metaChan, &lookupWG) From 0372d8dee7698373e694a3dcff4d6077d1058fbb Mon Sep 17 00:00:00 2001 From: phillip-stephens Date: Tue, 
10 Sep 2024 15:26:29 -0400 Subject: [PATCH 08/24] if TCP connection is closed, re-open it --- src/zdns/lookup.go | 98 ++++++++++++++++++++++++++++++-------------- src/zdns/resolver.go | 22 ++++++---- 2 files changed, 83 insertions(+), 37 deletions(-) diff --git a/src/zdns/lookup.go b/src/zdns/lookup.go index 8dd7226d..3f288726 100644 --- a/src/zdns/lookup.go +++ b/src/zdns/lookup.go @@ -452,8 +452,16 @@ func (r *Resolver) retryingLookup(ctx context.Context, q Question, nameServer *N result, status, err = doDoHLookup(ctx, connInfo.httpsClient, q, nameServer, recursive, r.ednsOptions, r.dnsSecEnabled, r.checkingDisabledBit) } else if r.dnsOverTLSEnabled { result, status, err = doDoTLookup(ctx, connInfo, q, nameServer, recursive, r.ednsOptions, r.dnsSecEnabled, r.checkingDisabledBit) + } else if connInfo.udpClient != nil { + result, status, err = wireLookupUDP(ctx, connInfo, q, nameServer, r.ednsOptions, recursive, r.dnsSecEnabled, r.checkingDisabledBit) + if status == StatusTruncated && connInfo.tcpClient != nil { + // result truncated, try again with TCP + result, status, err = wireLookupTCP(ctx, connInfo, q, nameServer, r.ednsOptions, recursive, r.dnsSecEnabled, r.checkingDisabledBit) + } + } else if connInfo.tcpClient != nil { + result, status, err = wireLookupTCP(ctx, connInfo, q, nameServer, r.ednsOptions, recursive, r.dnsSecEnabled, r.checkingDisabledBit) } else { - result, status, err = wireLookup(ctx, connInfo.udpClient, connInfo.tcpClient, connInfo.udpConn, connInfo.tcpConn, q, nameServer, r.ednsOptions, recursive, r.dnsSecEnabled, r.checkingDisabledBit) + return SingleQueryResult{}, StatusError, 0, errors.New("no connection info for nameserver") } if status != StatusTimeout || i == r.retries { return result, status, i + 1, err @@ -613,9 +621,8 @@ func doDoHLookup(ctx context.Context, httpClient *http.Client, q Question, nameS return constructSingleQueryResultFromDNSMsg(res, r) } -// wireLookup performs a DNS lookup on-the-wire with the given parameters -// 
Attempts a UDP lookup first, then falls back to TCP if necessary (if the UDP response encounters an error or is truncated) -func wireLookup(ctx context.Context, udp *dns.Client, tcp *dns.Client, udpConn, tcpConn *dns.Conn, q Question, nameServer *NameServer, ednsOptions []dns.EDNS0, recursive, dnssec, checkingDisabled bool) (SingleQueryResult, Status, error) { +// wireLookupTCP performs a DNS lookup on-the-wire over TCP with the given parameters +func wireLookupTCP(ctx context.Context, connInfo *ConnectionInfo, q Question, nameServer *NameServer, ednsOptions []dns.EDNS0, recursive, dnssec, checkingDisabled bool) (SingleQueryResult, Status, error) { res := SingleQueryResult{Answers: []interface{}{}, Authorities: []interface{}{}, Additional: []interface{}{}} res.Resolver = nameServer.String() @@ -632,37 +639,68 @@ func wireLookup(ctx context.Context, udp *dns.Client, tcp *dns.Client, udpConn, var r *dns.Msg var err error - if udp != nil { - res.Protocol = "udp" - if udpConn != nil { - dst, _ := net.ResolveUDPAddr("udp", nameServer.String()) - r, _, err = udp.ExchangeWithConnToContext(ctx, m, udpConn, dst) - } else { - r, _, err = udp.ExchangeContext(ctx, m, nameServer.String()) + if connInfo.tcpConn != nil && connInfo.tcpConn.RemoteAddr != nil && connInfo.tcpConn.RemoteAddr.String() == nameServer.String() { + // we have a connection to this nameserver, use it + res.Protocol = "tcp" + var addr *net.TCPAddr + addr, err = net.ResolveTCPAddr("tcp", nameServer.String()) + if err != nil { + return SingleQueryResult{}, StatusError, fmt.Errorf("could not resolve TCP address %s: %v", nameServer.String(), err) } - // if record comes back truncated, but we have a TCP connection, try again with that - if r != nil && (r.Truncated || r.Rcode == dns.RcodeBadTrunc) { - if tcp != nil { - return wireLookup(ctx, nil, tcp, udpConn, tcpConn, q, nameServer, ednsOptions, recursive, dnssec, checkingDisabled) - } else { - return res, StatusTruncated, err + r, _, err = 
connInfo.tcpClient.ExchangeWithConnToContext(ctx, m, connInfo.tcpConn, addr) + if err != nil && err.Error() == "EOF" { + // EOF error means the connection was closed, we'll re-open a connection and re-handshake + err = getNewTCPConn(nameServer, connInfo) + if err != nil { + return SingleQueryResult{}, StatusError, fmt.Errorf("could not get new TCP connection to nameserver %s: %v", nameServer.String(), err) } + return wireLookupTCP(ctx, connInfo, q, nameServer, ednsOptions, recursive, dnssec, checkingDisabled) } - } else if tcp != nil { - // TCP - if tcpConn != nil && tcpConn.RemoteAddr != nil && tcpConn.RemoteAddr.String() == nameServer.String() { - // we have a connection to this nameserver, use it - res.Protocol = "tcp" - var addr *net.TCPAddr - addr, err = net.ResolveTCPAddr("tcp", nameServer.String()) - if err != nil { - return SingleQueryResult{}, StatusError, fmt.Errorf("could not resolve TCP address %s: %v", nameServer.String(), err) + } else { + // no pre-existing connection, create a ephemeral one + res.Protocol = "tcp" + r, _, err = connInfo.tcpClient.ExchangeContext(ctx, m, nameServer.String()) + } + if err != nil || r == nil { + if nerr, ok := err.(net.Error); ok { + if nerr.Timeout() { + return res, StatusTimeout, nil } - r, _, err = tcp.ExchangeWithConnToContext(ctx, m, tcpConn, addr) - } else { - res.Protocol = "tcp" - r, _, err = tcp.ExchangeContext(ctx, m, nameServer.String()) } + return res, StatusError, err + } + + return constructSingleQueryResultFromDNSMsg(res, r) +} + +// wireLookupUDP performs a DNS lookup on-the-wire over UDP with the given parameters +func wireLookupUDP(ctx context.Context, connInfo *ConnectionInfo, q Question, nameServer *NameServer, ednsOptions []dns.EDNS0, recursive, dnssec, checkingDisabled bool) (SingleQueryResult, Status, error) { + res := SingleQueryResult{Answers: []interface{}{}, Authorities: []interface{}{}, Additional: []interface{}{}} + res.Resolver = nameServer.String() + res.Protocol = "udp" + + m := 
new(dns.Msg) + m.SetQuestion(dotName(q.Name), q.Type) + m.Question[0].Qclass = q.Class + m.RecursionDesired = recursive + m.CheckingDisabled = checkingDisabled + + m.SetEdns0(1232, dnssec) + if ednsOpt := m.IsEdns0(); ednsOpt != nil { + ednsOpt.Option = append(ednsOpt.Option, ednsOptions...) + } + + var r *dns.Msg + var err error + if connInfo.udpConn != nil { + dst, _ := net.ResolveUDPAddr("udp", nameServer.String()) + r, _, err = connInfo.udpClient.ExchangeWithConnToContext(ctx, m, connInfo.udpConn, dst) + } else { + r, _, err = connInfo.udpClient.ExchangeContext(ctx, m, nameServer.String()) + } + // if record comes back truncated, but we have a TCP connection, try again with that + if r != nil && (r.Truncated || r.Rcode == dns.RcodeBadTrunc) { + return res, StatusTruncated, err } if err != nil || r == nil { if nerr, ok := err.(net.Error); ok { diff --git a/src/zdns/resolver.go b/src/zdns/resolver.go index eda1620d..4f751811 100644 --- a/src/zdns/resolver.go +++ b/src/zdns/resolver.go @@ -464,16 +464,12 @@ func (r *Resolver) getConnectionInfo(nameServer *NameServer) (*ConnectionInfo, e } } if r.transportMode == TCPOnly && r.shouldRecycleSockets { - // create persistent TCP connection to nameserver if connInfo.tcpConn == nil || connInfo.tcpConn.RemoteAddr != nil || connInfo.tcpConn.RemoteAddr.String() != nameServer.String() { - // RemoteAddr has changed, we need to re-handshake - conn, err := net.DialTCP("tcp", &net.TCPAddr{IP: connInfo.localAddr}, &net.TCPAddr{IP: nameServer.IP, Port: int(nameServer.Port)}) + // need to re-handshake + err := getNewTCPConn(nameServer, connInfo) if err != nil { - return nil, fmt.Errorf("unable to create TCP connection for nameserver %s: %w", nameServer.String(), err) + return nil, errors.Wrap(err, "unable to create TCP connection") } - connInfo.tcpConn = new(dns.Conn) - connInfo.tcpConn.Conn = conn - connInfo.tcpConn.RemoteAddr = &net.TCPAddr{IP: nameServer.IP, Port: int(nameServer.Port)} } } if r.dnsOverHTTPSEnabled { @@ -534,6 
+530,18 @@ func (r *Resolver) getConnectionInfo(nameServer *NameServer) (*ConnectionInfo, e return connInfo, nil } +func getNewTCPConn(nameServer *NameServer, connInfo *ConnectionInfo) error { + // create persistent TCP connection to nameserver + conn, err := net.DialTCP("tcp", &net.TCPAddr{IP: connInfo.localAddr}, &net.TCPAddr{IP: nameServer.IP, Port: int(nameServer.Port)}) + if err != nil { + return fmt.Errorf("unable to create TCP connection for nameserver %s: %w", nameServer.String(), err) + } + connInfo.tcpConn = new(dns.Conn) + connInfo.tcpConn.Conn = conn + connInfo.tcpConn.RemoteAddr = &net.TCPAddr{IP: nameServer.IP, Port: int(nameServer.Port)} + return nil +} + // ExternalLookup performs a single lookup of a DNS question, q, against an external name server. // dstServer, (ex: '1.1.1.1:53') can be set to over-ride the nameservers defined in the ResolverConfig. // If dstServer is not specified (ie. is an empty string), a random external name server will be used from the resolver's list of external name servers. 
From d114668885f55bb497f36662624866f32574eb00 Mon Sep 17 00:00:00 2001 From: phillip-stephens Date: Tue, 10 Sep 2024 15:44:11 -0400 Subject: [PATCH 09/24] don't loop in retrying tcp connection --- src/zdns/lookup.go | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/zdns/lookup.go b/src/zdns/lookup.go index 3f288726..32f7c498 100644 --- a/src/zdns/lookup.go +++ b/src/zdns/lookup.go @@ -456,10 +456,10 @@ func (r *Resolver) retryingLookup(ctx context.Context, q Question, nameServer *N result, status, err = wireLookupUDP(ctx, connInfo, q, nameServer, r.ednsOptions, recursive, r.dnsSecEnabled, r.checkingDisabledBit) if status == StatusTruncated && connInfo.tcpClient != nil { // result truncated, try again with TCP - result, status, err = wireLookupTCP(ctx, connInfo, q, nameServer, r.ednsOptions, recursive, r.dnsSecEnabled, r.checkingDisabledBit) + result, status, err = wireLookupTCP(ctx, connInfo, q, nameServer, r.ednsOptions, recursive, r.dnsSecEnabled, r.checkingDisabledBit, true) } } else if connInfo.tcpClient != nil { - result, status, err = wireLookupTCP(ctx, connInfo, q, nameServer, r.ednsOptions, recursive, r.dnsSecEnabled, r.checkingDisabledBit) + result, status, err = wireLookupTCP(ctx, connInfo, q, nameServer, r.ednsOptions, recursive, r.dnsSecEnabled, r.checkingDisabledBit, true) } else { return SingleQueryResult{}, StatusError, 0, errors.New("no connection info for nameserver") } @@ -622,7 +622,7 @@ func doDoHLookup(ctx context.Context, httpClient *http.Client, q Question, nameS } // wireLookupTCP performs a DNS lookup on-the-wire over TCP with the given parameters -func wireLookupTCP(ctx context.Context, connInfo *ConnectionInfo, q Question, nameServer *NameServer, ednsOptions []dns.EDNS0, recursive, dnssec, checkingDisabled bool) (SingleQueryResult, Status, error) { +func wireLookupTCP(ctx context.Context, connInfo *ConnectionInfo, q Question, nameServer *NameServer, ednsOptions []dns.EDNS0, recursive, dnssec, 
checkingDisabled, retryOnConnClosing bool) (SingleQueryResult, Status, error) { res := SingleQueryResult{Answers: []interface{}{}, Authorities: []interface{}{}, Additional: []interface{}{}} res.Resolver = nameServer.String() @@ -648,13 +648,13 @@ func wireLookupTCP(ctx context.Context, connInfo *ConnectionInfo, q Question, na return SingleQueryResult{}, StatusError, fmt.Errorf("could not resolve TCP address %s: %v", nameServer.String(), err) } r, _, err = connInfo.tcpClient.ExchangeWithConnToContext(ctx, m, connInfo.tcpConn, addr) - if err != nil && err.Error() == "EOF" { + if retryOnConnClosing && err != nil && err.Error() == "EOF" { // EOF error means the connection was closed, we'll re-open a connection and re-handshake err = getNewTCPConn(nameServer, connInfo) if err != nil { return SingleQueryResult{}, StatusError, fmt.Errorf("could not get new TCP connection to nameserver %s: %v", nameServer.String(), err) } - return wireLookupTCP(ctx, connInfo, q, nameServer, ednsOptions, recursive, dnssec, checkingDisabled) + return wireLookupTCP(ctx, connInfo, q, nameServer, ednsOptions, recursive, dnssec, checkingDisabled, false) } } else { // no pre-existing connection, create a ephemeral one From afe5781cb493a404b6d275f31ec58684294bfee0 Mon Sep 17 00:00:00 2001 From: phillip-stephens Date: Tue, 10 Sep 2024 15:52:42 -0400 Subject: [PATCH 10/24] spelling --- src/zdns/lookup.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/zdns/lookup.go b/src/zdns/lookup.go index 32f7c498..c7596ff4 100644 --- a/src/zdns/lookup.go +++ b/src/zdns/lookup.go @@ -657,7 +657,7 @@ func wireLookupTCP(ctx context.Context, connInfo *ConnectionInfo, q Question, na return wireLookupTCP(ctx, connInfo, q, nameServer, ednsOptions, recursive, dnssec, checkingDisabled, false) } } else { - // no pre-existing connection, create a ephemeral one + // no pre-existing connection, create an ephemeral one res.Protocol = "tcp" r, _, err = connInfo.tcpClient.ExchangeContext(ctx, m, 
nameServer.String()) } From e3aa7c8f2f185a5d778c8e120dc51c43de5d41bc Mon Sep 17 00:00:00 2001 From: phillip-stephens Date: Tue, 10 Sep 2024 16:22:39 -0400 Subject: [PATCH 11/24] close TCP conns in Close() --- src/zdns/resolver.go | 58 +++++++++++++++++++++++++++++++++++--------- 1 file changed, 46 insertions(+), 12 deletions(-) diff --git a/src/zdns/resolver.go b/src/zdns/resolver.go index 4f751811..129e1449 100644 --- a/src/zdns/resolver.go +++ b/src/zdns/resolver.go @@ -531,6 +531,12 @@ func (r *Resolver) getConnectionInfo(nameServer *NameServer) (*ConnectionInfo, e } func getNewTCPConn(nameServer *NameServer, connInfo *ConnectionInfo) error { + // close any existing TCP connection + if connInfo.tcpConn != nil { + if err := connInfo.tcpConn.Close(); err != nil { + return fmt.Errorf("error closing existing TCP connection: %w", err) + } + } // create persistent TCP connection to nameserver conn, err := net.DialTCP("tcp", &net.TCPAddr{IP: connInfo.localAddr}, &net.TCPAddr{IP: nameServer.IP, Port: int(nameServer.Port)}) if err != nil { @@ -583,24 +589,52 @@ func (r *Resolver) IterativeLookup(q *Question) (*SingleQueryResult, Trace, Stat // Close cleans up any resources used by the resolver. This should be called when the resolver is no longer needed. // Lookup will panic if called after Close. 
func (r *Resolver) Close() { - if r.connInfoIPv4Internet != nil && r.connInfoIPv4Internet.udpConn != nil { - if err := r.connInfoIPv4Internet.udpConn.Close(); err != nil { - log.Errorf("error closing IPv4 connection: %v", err) + if r.connInfoIPv4Internet != nil { + if r.connInfoIPv4Internet.udpConn != nil { + if err := r.connInfoIPv4Internet.udpConn.Close(); err != nil { + log.Errorf("error closing UDP IPv4 connection: %v", err) + } + } + if r.connInfoIPv4Internet.tcpConn != nil { + if err := r.connInfoIPv4Internet.tcpConn.Close(); err != nil { + log.Errorf("error closing TCP IPv4 connection: %v", err) + } } } - if r.connInfoIPv6Internet != nil && r.connInfoIPv6Internet.udpConn != nil { - if err := r.connInfoIPv6Internet.udpConn.Close(); err != nil { - log.Errorf("error closing IPv6 connection: %v", err) + if r.connInfoIPv6Internet != nil { + if r.connInfoIPv6Internet.udpConn != nil { + if err := r.connInfoIPv6Internet.udpConn.Close(); err != nil { + log.Errorf("error closing UDP IPv6 connection: %v", err) + } + } + if r.connInfoIPv6Internet.tcpConn != nil { + if err := r.connInfoIPv6Internet.tcpConn.Close(); err != nil { + log.Errorf("error closing TCP IPv6 connection: %v", err) + } } } - if r.connInfoIPv4Loopback != nil && r.connInfoIPv4Loopback.udpConn != nil { - if err := r.connInfoIPv4Loopback.udpConn.Close(); err != nil { - log.Errorf("error closing IPv4 loopback connection: %v", err) + if r.connInfoIPv4Loopback != nil { + if r.connInfoIPv4Loopback.udpConn != nil { + if err := r.connInfoIPv4Loopback.udpConn.Close(); err != nil { + log.Errorf("error closing IPv4 UDP loopback connection: %v", err) + } + } + if r.connInfoIPv4Loopback.tcpConn != nil { + if err := r.connInfoIPv4Loopback.tcpConn.Close(); err != nil { + log.Errorf("error closing IPv4 TCP loopback connection: %v", err) + } } } - if r.connInfoIPv6Loopback != nil && r.connInfoIPv6Loopback.udpConn != nil { - if err := r.connInfoIPv6Loopback.udpConn.Close(); err != nil { - log.Errorf("error closing IPv6 
loopback connection: %v", err) + if r.connInfoIPv6Loopback != nil { + if r.connInfoIPv6Loopback.udpConn != nil { + if err := r.connInfoIPv6Loopback.udpConn.Close(); err != nil { + log.Errorf("error closing IPv6 UDP loopback connection: %v", err) + } + } + if r.connInfoIPv6Loopback.tcpConn != nil { + if err := r.connInfoIPv6Loopback.tcpConn.Close(); err != nil { + log.Errorf("error closing IPv6 TCP loopback connection: %v", err) + } } } } From 2cb78779ce52bc8428a732db87fba7f21691f609 Mon Sep 17 00:00:00 2001 From: phillip-stephens Date: Tue, 10 Sep 2024 16:31:03 -0400 Subject: [PATCH 12/24] trying multiple de-multiplexors --- src/cli/worker_manager.go | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/src/cli/worker_manager.go b/src/cli/worker_manager.go index beb1b9a6..5701b46b 100644 --- a/src/cli/worker_manager.go +++ b/src/cli/worker_manager.go @@ -458,7 +458,6 @@ type InputLineWithNameServer struct { // The goal is that a query for a single name server will consistently go to the same worker pool which 1+ threads will read from // This is especially useful for TLS/TCP/HTTPS based lookups where repeating the initial handshakes would be wasteful func inputDeMultiplexer(nameServers []zdns.NameServer, inChan <-chan string, workerPools *WorkerPools, wg *sync.WaitGroup) error { - wg.Add(1) defer wg.Done() // defer closing the worker pool chans defer func() { @@ -553,12 +552,16 @@ func Run(gc CLIConf) { log.Fatal(fmt.Sprintf("could not feed input channel: %v", inErr)) } }() - go func() { - plexErr := inputDeMultiplexer(uniqNameServers, inChan, workerPools, &routineWG) - if plexErr != nil { - log.Fatal(fmt.Sprintf("could not de-multiplex input channel: %v", plexErr)) - } - }() + const numberOfDeMultiplexers = 5 + for i := 0; i < numberOfDeMultiplexers; i++ { + go func() { + plexErr := inputDeMultiplexer(uniqNameServers, inChan, workerPools, &routineWG) + if plexErr != nil { + log.Fatal(fmt.Sprintf("could not de-multiplex input channel: 
%v", plexErr)) + } + }() + } + routineWG.Add(numberOfDeMultiplexers) go func() { outErr := outHandler.WriteResults(outChan, &routineWG) if outErr != nil { From 58b790aabec7de0dd1a0981f62b3fc8bae518cb6 Mon Sep 17 00:00:00 2001 From: phillip-stephens Date: Tue, 10 Sep 2024 16:36:19 -0400 Subject: [PATCH 13/24] Revert "trying multiple de-multiplexors" This reverts commit 2cb78779ce52bc8428a732db87fba7f21691f609. --- src/cli/worker_manager.go | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/src/cli/worker_manager.go b/src/cli/worker_manager.go index 5701b46b..beb1b9a6 100644 --- a/src/cli/worker_manager.go +++ b/src/cli/worker_manager.go @@ -458,6 +458,7 @@ type InputLineWithNameServer struct { // The goal is that a query for a single name server will consistently go to the same worker pool which 1+ threads will read from // This is especially useful for TLS/TCP/HTTPS based lookups where repeating the initial handshakes would be wasteful func inputDeMultiplexer(nameServers []zdns.NameServer, inChan <-chan string, workerPools *WorkerPools, wg *sync.WaitGroup) error { + wg.Add(1) defer wg.Done() // defer closing the worker pool chans defer func() { @@ -552,16 +553,12 @@ func Run(gc CLIConf) { log.Fatal(fmt.Sprintf("could not feed input channel: %v", inErr)) } }() - const numberOfDeMultiplexers = 5 - for i := 0; i < numberOfDeMultiplexers; i++ { - go func() { - plexErr := inputDeMultiplexer(uniqNameServers, inChan, workerPools, &routineWG) - if plexErr != nil { - log.Fatal(fmt.Sprintf("could not de-multiplex input channel: %v", plexErr)) - } - }() - } - routineWG.Add(numberOfDeMultiplexers) + go func() { + plexErr := inputDeMultiplexer(uniqNameServers, inChan, workerPools, &routineWG) + if plexErr != nil { + log.Fatal(fmt.Sprintf("could not de-multiplex input channel: %v", plexErr)) + } + }() go func() { outErr := outHandler.WriteResults(outChan, &routineWG) if outErr != nil { From efb3d5d3f4a4dc7e606fa21af346cf1edcffca1a Mon Sep 17 
00:00:00 2001 From: phillip-stephens Date: Tue, 10 Sep 2024 16:37:51 -0400 Subject: [PATCH 14/24] TEST - check how long non-network activity takes --- src/zdns/lookup.go | 1 + 1 file changed, 1 insertion(+) diff --git a/src/zdns/lookup.go b/src/zdns/lookup.go index c7596ff4..11b1e886 100644 --- a/src/zdns/lookup.go +++ b/src/zdns/lookup.go @@ -472,6 +472,7 @@ func (r *Resolver) retryingLookup(ctx context.Context, q Question, nameServer *N } func doDoTLookup(ctx context.Context, connInfo *ConnectionInfo, q Question, nameServer *NameServer, recursive bool, ednsOptions []dns.EDNS0, dnssec bool, checkingDisabled bool) (SingleQueryResult, Status, error) { + return SingleQueryResult{}, StatusError, errors.New("DoT not implemented") m := new(dns.Msg) m.SetQuestion(dotName(q.Name), q.Type) m.Question[0].Qclass = q.Class From fbee6de420c7a63c1df2de73c77b24cef31c5c1c Mon Sep 17 00:00:00 2001 From: phillip-stephens Date: Tue, 10 Sep 2024 16:41:08 -0400 Subject: [PATCH 15/24] TEST - :( --- src/zdns/lookup.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/zdns/lookup.go b/src/zdns/lookup.go index 11b1e886..ad7a5c0d 100644 --- a/src/zdns/lookup.go +++ b/src/zdns/lookup.go @@ -472,7 +472,6 @@ func (r *Resolver) retryingLookup(ctx context.Context, q Question, nameServer *N } func doDoTLookup(ctx context.Context, connInfo *ConnectionInfo, q Question, nameServer *NameServer, recursive bool, ednsOptions []dns.EDNS0, dnssec bool, checkingDisabled bool) (SingleQueryResult, Status, error) { - return SingleQueryResult{}, StatusError, errors.New("DoT not implemented") m := new(dns.Msg) m.SetQuestion(dotName(q.Name), q.Type) m.Question[0].Qclass = q.Class @@ -624,6 +623,7 @@ func doDoHLookup(ctx context.Context, httpClient *http.Client, q Question, nameS // wireLookupTCP performs a DNS lookup on-the-wire over TCP with the given parameters func wireLookupTCP(ctx context.Context, connInfo *ConnectionInfo, q Question, nameServer *NameServer, ednsOptions []dns.EDNS0, 
recursive, dnssec, checkingDisabled, retryOnConnClosing bool) (SingleQueryResult, Status, error) { + return SingleQueryResult{}, StatusError, errors.New("TCP not implemented") res := SingleQueryResult{Answers: []interface{}{}, Authorities: []interface{}{}, Additional: []interface{}{}} res.Resolver = nameServer.String() From e7e70053b5329423ce212b284c7d186b7d5737ce Mon Sep 17 00:00:00 2001 From: phillip-stephens Date: Tue, 10 Sep 2024 17:15:50 -0400 Subject: [PATCH 16/24] removed testing line --- src/zdns/lookup.go | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/zdns/lookup.go b/src/zdns/lookup.go index ad7a5c0d..2bb8d121 100644 --- a/src/zdns/lookup.go +++ b/src/zdns/lookup.go @@ -16,17 +16,16 @@ package zdns import ( "context" "fmt" - "io" - "net" - "regexp" - "strings" - "github.com/pkg/errors" log "github.com/sirupsen/logrus" "github.com/zmap/dns" "github.com/zmap/zcrypto/tls" "github.com/zmap/zgrab2/lib/http" "github.com/zmap/zgrab2/lib/output" + "io" + "net" + "regexp" + "strings" "github.com/zmap/zdns/src/internal/util" ) @@ -623,7 +622,6 @@ func doDoHLookup(ctx context.Context, httpClient *http.Client, q Question, nameS // wireLookupTCP performs a DNS lookup on-the-wire over TCP with the given parameters func wireLookupTCP(ctx context.Context, connInfo *ConnectionInfo, q Question, nameServer *NameServer, ednsOptions []dns.EDNS0, recursive, dnssec, checkingDisabled, retryOnConnClosing bool) (SingleQueryResult, Status, error) { - return SingleQueryResult{}, StatusError, errors.New("TCP not implemented") res := SingleQueryResult{Answers: []interface{}{}, Authorities: []interface{}{}, Additional: []interface{}{}} res.Resolver = nameServer.String() From 0249fa1de5f389ac25a39d59c7c244ebcfa5ad1d Mon Sep 17 00:00:00 2001 From: phillip-stephens Date: Tue, 10 Sep 2024 17:18:57 -0400 Subject: [PATCH 17/24] trying giving the pool channels a capacity --- src/cli/worker_manager.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff 
--git a/src/cli/worker_manager.go b/src/cli/worker_manager.go index beb1b9a6..ef29f6db 100644 --- a/src/cli/worker_manager.go +++ b/src/cli/worker_manager.go @@ -441,7 +441,7 @@ type WorkerPools struct { func NewWorkerPools(numPools int) *WorkerPools { workerPools := make([]chan *InputLineWithNameServer, numPools) for i := 0; i < numPools; i++ { - workerPools[i] = make(chan *InputLineWithNameServer) + workerPools[i] = make(chan *InputLineWithNameServer, 10) } return &WorkerPools{WorkerPools: workerPools} } From 342fd090b237e7f1dfeb477a3c9713eb8a5b7c97 Mon Sep 17 00:00:00 2001 From: phillip-stephens Date: Wed, 11 Sep 2024 11:30:36 -0400 Subject: [PATCH 18/24] implement work-balancing scheme --- src/cli/worker_manager.go | 180 ++++++++++++++++++++++++-------------- 1 file changed, 112 insertions(+), 68 deletions(-) diff --git a/src/cli/worker_manager.go b/src/cli/worker_manager.go index ef29f6db..d0eb817c 100644 --- a/src/cli/worker_manager.go +++ b/src/cli/worker_manager.go @@ -435,15 +435,16 @@ func populateLocalAddresses(gc *CLIConf, config *zdns.ResolverConfig) (*zdns.Res // WorkerPools are a collection of channels that workers can read from // 1+ threads will read from a pooled channel, and the inputDeMultiplexer will send input to the appropriate channel type WorkerPools struct { - WorkerPools []chan *InputLineWithNameServer + WorkerPools []chan *InputLineWithNameServer + GlobalTaskPool chan *InputLineWithNameServer } func NewWorkerPools(numPools int) *WorkerPools { workerPools := make([]chan *InputLineWithNameServer, numPools) for i := 0; i < numPools; i++ { - workerPools[i] = make(chan *InputLineWithNameServer, 10) + workerPools[i] = make(chan *InputLineWithNameServer, 1) } - return &WorkerPools{WorkerPools: workerPools} + return &WorkerPools{WorkerPools: workerPools, GlobalTaskPool: make(chan *InputLineWithNameServer)} } // InputLineWithNameServer is a struct that contains a line of input and the name server to use for the lookup @@ -457,20 +458,33 @@ type 
InputLineWithNameServer struct { // inputDeMultiplxer is a single goroutine that reads from the input channel and sends the input to the appropriate worker pool channel // The goal is that a query for a single name server will consistently go to the same worker pool which 1+ threads will read from // This is especially useful for TLS/TCP/HTTPS based lookups where repeating the initial handshakes would be wasteful +// Work Balancing +// The GlobalTaskPool is used to address work imbalance between worker pools. If a query should go to Pool A but Pool A is busy, it will go to the GlobalTaskPool +// Workers will check the GlobalTaskPool only if their pool is empty. This means they will tend to re-use their connections, but help out other pools if they're idle func inputDeMultiplexer(nameServers []zdns.NameServer, inChan <-chan string, workerPools *WorkerPools, wg *sync.WaitGroup) error { - wg.Add(1) defer wg.Done() // defer closing the worker pool chans defer func() { for _, pool := range workerPools.WorkerPools { close(pool) } + close(workerPools.GlobalTaskPool) }() for line := range inChan { nsIndex := rand.Intn(len(nameServers)) randomNS := nameServers[nsIndex] chanID := nsIndex % len(workerPools.WorkerPools) - workerPools.WorkerPools[chanID] <- &InputLineWithNameServer{Line: line, NameServer: &randomNS} + work := &InputLineWithNameServer{Line: line, NameServer: &randomNS} + // for each work item, we prefer to send it to the assigned worker pool for the name server. 
If that pool is busy, we'll send it to the global task pool + select { + case workerPools.WorkerPools[chanID] <- work: // prefer to send to the worker pool for the name server + default: + // worker pool is busy, we'll take first available spot between the global task pool and the worker pool + select { + case workerPools.GlobalTaskPool <- work: + case workerPools.WorkerPools[chanID] <- work: + } + } } return nil } @@ -565,7 +579,7 @@ func Run(gc CLIConf) { log.Fatal(fmt.Sprintf("could not write output results from output channel: %v", outErr)) } }() - routineWG.Add(2) + routineWG.Add(3) // create pool of worker goroutines var lookupWG sync.WaitGroup @@ -577,7 +591,7 @@ func Run(gc CLIConf) { // assign each worker to a worker pool, we'll loop around if we have more workers than pools channelID := i % len(workerPools.WorkerPools) go func(threadID int) { - initWorkerErr := doLookupWorker(&gc, resolverConfig, workerPools.WorkerPools[channelID], outChan, metaChan, &lookupWG) + initWorkerErr := doLookupWorker(&gc, resolverConfig, workerPools.WorkerPools[channelID], workerPools.GlobalTaskPool, outChan, metaChan, &lookupWG) if initWorkerErr != nil { log.Fatalf("could not start lookup worker #%d: %v", i, initWorkerErr) } @@ -628,7 +642,7 @@ func Run(gc CLIConf) { } // doLookupWorker is a single worker thread that processes lookups from the input channel. It calls wg.Done when it is finished. 
-func doLookupWorker(gc *CLIConf, rc *zdns.ResolverConfig, input <-chan *InputLineWithNameServer, output chan<- string, metaChan chan<- routineMetadata, wg *sync.WaitGroup) error { +func doLookupWorker(gc *CLIConf, rc *zdns.ResolverConfig, preferredWorkChan, globalWorkChan <-chan *InputLineWithNameServer, output chan<- string, metaChan chan<- routineMetadata, wg *sync.WaitGroup) error { defer wg.Done() resolver, err := zdns.InitResolver(rc) if err != nil { @@ -636,70 +650,34 @@ func doLookupWorker(gc *CLIConf, rc *zdns.ResolverConfig, input <-chan *InputLin } var metadata routineMetadata metadata.Status = make(map[zdns.Status]int) - for line := range input { - // we'll process each module sequentially, parallelism is per-domain - res, rawName, nameServer := parseInputLine(gc, rc, line) - res.Name = rawName - // handle per-module lookups - for moduleName, module := range gc.ActiveModules { - if moduleName == "AXFR" { - // special case, AXFR has its own nameserver handling. We'll only take nameservers if the user provides it - // not the "suggestion" from the de-multiplexor - if nameServer.String() == line.NameServer.String() { - // this name server is the suggested one from the de-multiplexor, we'll remove it - nameServer = nil - } - } - var innerRes interface{} - var trace zdns.Trace - var status zdns.Status - var err error - var changed bool - var lookupName string - lookupName, changed = makeName(rawName, gc.NamePrefix, gc.NameOverride) - if changed { - res.AlteredName = lookupName + var task *InputLineWithNameServer + var ok bool + +WorkerLoop: + for { + // Check its own resolver tasks first (reusing connections) + select { + case task, ok = <-preferredWorkChan: + if !ok { + // inputDeMultiplexer has closed the channel, we're done + break WorkerLoop } - res.Class = dns.Class(gc.Class).String() - - startTime := time.Now() - innerRes, trace, status, err = module.Lookup(resolver, lookupName, nameServer) - - lookupRes := zdns.SingleModuleResult{ - Timestamp: 
time.Now().Format(gc.TimeFormat), - Duration: time.Since(startTime).Seconds(), - } - if status != zdns.StatusNoOutput { - lookupRes.Status = string(status) - lookupRes.Data = innerRes - lookupRes.Trace = trace - if err != nil { - lookupRes.Error = err.Error() + handleWorkerInput(gc, rc, task, resolver, &metadata, output) + default: + // No tasks in its own resolver, wait on either + select { + case task, ok = <-preferredWorkChan: + if !ok { + break WorkerLoop } - res.Results[moduleName] = lookupRes - } - metadata.Status[status]++ - metadata.Lookups++ - } - if len(res.Results) > 0 { - v, _ := version.NewVersion("0.0.0") - o := &sheriff.Options{ - Groups: gc.OutputGroups, - ApiVersion: v, - IncludeEmptyTag: true, - } - data, err := sheriff.Marshal(o, res) - if err != nil { - log.Fatalf("unable to marshal result to JSON: %v", err) - } - cleansedData := replaceIntSliceInterface(data) - jsonRes, err := json.Marshal(cleansedData) - if err != nil { - log.Fatalf("unable to marshal JSON result: %v", err) + handleWorkerInput(gc, rc, task, resolver, &metadata, output) + case task, ok = <-globalWorkChan: + if !ok { + break WorkerLoop + } + handleWorkerInput(gc, rc, task, resolver, &metadata, output) } - output <- string(jsonRes) } - metadata.Names++ } // close the resolver, freeing up resources resolver.Close() @@ -707,6 +685,72 @@ func doLookupWorker(gc *CLIConf, rc *zdns.ResolverConfig, input <-chan *InputLin return nil } +func handleWorkerInput(gc *CLIConf, rc *zdns.ResolverConfig, line *InputLineWithNameServer, resolver *zdns.Resolver, metadata *routineMetadata, output chan<- string) { + // we'll process each module sequentially, parallelism is per-domain + res, rawName, nameServer := parseInputLine(gc, rc, line) + res.Name = rawName + // handle per-module lookups + for moduleName, module := range gc.ActiveModules { + if moduleName == "AXFR" { + // special case, AXFR has its own nameserver handling. 
We'll only take nameservers if the user provides it + // not the "suggestion" from the de-multiplexor + if nameServer.String() == line.NameServer.String() { + // this name server is the suggested one from the de-multiplexor, we'll remove it + nameServer = nil + } + } + var innerRes interface{} + var trace zdns.Trace + var status zdns.Status + var err error + var changed bool + var lookupName string + lookupName, changed = makeName(rawName, gc.NamePrefix, gc.NameOverride) + if changed { + res.AlteredName = lookupName + } + res.Class = dns.Class(gc.Class).String() + + startTime := time.Now() + innerRes, trace, status, err = module.Lookup(resolver, lookupName, nameServer) + + lookupRes := zdns.SingleModuleResult{ + Timestamp: time.Now().Format(gc.TimeFormat), + Duration: time.Since(startTime).Seconds(), + } + if status != zdns.StatusNoOutput { + lookupRes.Status = string(status) + lookupRes.Data = innerRes + lookupRes.Trace = trace + if err != nil { + lookupRes.Error = err.Error() + } + res.Results[moduleName] = lookupRes + } + metadata.Status[status]++ + metadata.Lookups++ + } + if len(res.Results) > 0 { + v, _ := version.NewVersion("0.0.0") + o := &sheriff.Options{ + Groups: gc.OutputGroups, + ApiVersion: v, + IncludeEmptyTag: true, + } + data, err := sheriff.Marshal(o, res) + if err != nil { + log.Fatalf("unable to marshal result to JSON: %v", err) + } + cleansedData := replaceIntSliceInterface(data) + jsonRes, err := json.Marshal(cleansedData) + if err != nil { + log.Fatalf("unable to marshal JSON result: %v", err) + } + output <- string(jsonRes) + } + metadata.Names++ +} + func parseInputLine(gc *CLIConf, rc *zdns.ResolverConfig, line *InputLineWithNameServer) (*zdns.Result, string, *zdns.NameServer) { res := zdns.Result{Results: make(map[string]zdns.SingleModuleResult, len(gc.ActiveModules))} // get the fields that won't change for each lookup module From b439c0bd815a991843e37e1d49cf4fac0ffcfae0 Mon Sep 17 00:00:00 2001 From: phillip-stephens Date: Wed, 11 Sep 
2024 11:55:40 -0400 Subject: [PATCH 19/24] added small wait before going to global queue --- src/cli/worker_manager.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/cli/worker_manager.go b/src/cli/worker_manager.go index d0eb817c..5fcbe3e2 100644 --- a/src/cli/worker_manager.go +++ b/src/cli/worker_manager.go @@ -663,7 +663,7 @@ WorkerLoop: break WorkerLoop } handleWorkerInput(gc, rc, task, resolver, &metadata, output) - default: + case <-time.After(time.Millisecond * 10): // No tasks in its own resolver, wait on either select { case task, ok = <-preferredWorkChan: From f83cf19671d5e91b76d58ff20e2c7471cc5eeddf Mon Sep 17 00:00:00 2001 From: phillip-stephens Date: Wed, 11 Sep 2024 14:03:51 -0400 Subject: [PATCH 20/24] fix errors if destination closes the TCP connection --- src/cli/worker_manager.go | 2 +- src/zdns/lookup.go | 10 ++++++---- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/src/cli/worker_manager.go b/src/cli/worker_manager.go index 5fcbe3e2..d0eb817c 100644 --- a/src/cli/worker_manager.go +++ b/src/cli/worker_manager.go @@ -663,7 +663,7 @@ WorkerLoop: break WorkerLoop } handleWorkerInput(gc, rc, task, resolver, &metadata, output) - case <-time.After(time.Millisecond * 10): + default: // No tasks in its own resolver, wait on either select { case task, ok = <-preferredWorkChan: diff --git a/src/zdns/lookup.go b/src/zdns/lookup.go index 2bb8d121..0bb945a1 100644 --- a/src/zdns/lookup.go +++ b/src/zdns/lookup.go @@ -648,12 +648,14 @@ func wireLookupTCP(ctx context.Context, connInfo *ConnectionInfo, q Question, na } r, _, err = connInfo.tcpClient.ExchangeWithConnToContext(ctx, m, connInfo.tcpConn, addr) if retryOnConnClosing && err != nil && err.Error() == "EOF" { - // EOF error means the connection was closed, we'll re-open a connection and re-handshake - err = getNewTCPConn(nameServer, connInfo) + // EOF error means the connection was closed, we'll remove the connection (it'll be recreated on the next iteration) + // 
and try again + err = connInfo.tcpConn.Conn.Close() if err != nil { - return SingleQueryResult{}, StatusError, fmt.Errorf("could not get new TCP connection to nameserver %s: %v", nameServer.String(), err) + log.Errorf("error closing TCP connection: %v", err) } - return wireLookupTCP(ctx, connInfo, q, nameServer, ednsOptions, recursive, dnssec, checkingDisabled, false) + connInfo.tcpConn = nil + r, _, err = connInfo.tcpClient.ExchangeContext(ctx, m, nameServer.String()) } } else { // no pre-existing connection, create an ephemeral one From 6086ce35194dc772a63ca5eca37dd7f537a8e89a Mon Sep 17 00:00:00 2001 From: phillip-stephens Date: Wed, 11 Sep 2024 14:24:59 -0400 Subject: [PATCH 21/24] lint --- src/zdns/lookup.go | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/zdns/lookup.go b/src/zdns/lookup.go index 0bb945a1..1c800a94 100644 --- a/src/zdns/lookup.go +++ b/src/zdns/lookup.go @@ -16,16 +16,17 @@ package zdns import ( "context" "fmt" + "io" + "net" + "regexp" + "strings" + "github.com/pkg/errors" log "github.com/sirupsen/logrus" "github.com/zmap/dns" "github.com/zmap/zcrypto/tls" "github.com/zmap/zgrab2/lib/http" "github.com/zmap/zgrab2/lib/output" - "io" - "net" - "regexp" - "strings" "github.com/zmap/zdns/src/internal/util" ) From 172d1cc9e6cbb383b3c0f8b9e64570f11004a054 Mon Sep 17 00:00:00 2001 From: phillip-stephens Date: Wed, 11 Sep 2024 15:33:44 -0400 Subject: [PATCH 22/24] refactor - coalesce language around worker channels --- src/cli/worker_manager.go | 69 ++++++++++++++++++++------------------- 1 file changed, 35 insertions(+), 34 deletions(-) diff --git a/src/cli/worker_manager.go b/src/cli/worker_manager.go index d0eb817c..a8ccd2fb 100644 --- a/src/cli/worker_manager.go +++ b/src/cli/worker_manager.go @@ -432,57 +432,58 @@ func populateLocalAddresses(gc *CLIConf, config *zdns.ResolverConfig) (*zdns.Res return config, nil } -// WorkerPools are a collection of channels that workers can read from -// 1+ threads will read from 
a pooled channel, and the inputDeMultiplexer will send input to the appropriate channel -type WorkerPools struct { - WorkerPools []chan *InputLineWithNameServer - GlobalTaskPool chan *InputLineWithNameServer +// WorkChans are a collection of channels that workers can read from +// Worker threads are assigned a PriorityWorkChan, which is a channel where all queries are directed at a single nameserver +// If a thread is idle, it will read from the GlobalWorkChan, helping to relieve work imbalance between worker nameservers +type WorkChans struct { + PriorityWorkChans []chan *InputLineWithNameServer + GlobalWorkChan chan *InputLineWithNameServer } -func NewWorkerPools(numPools int) *WorkerPools { - workerPools := make([]chan *InputLineWithNameServer, numPools) - for i := 0; i < numPools; i++ { +// NewWorkerChans creates numPriorityChans priority work channels and a global work channel +func NewWorkerChans(numPriorityChans int) *WorkChans { + workerPools := make([]chan *InputLineWithNameServer, numPriorityChans) + for i := 0; i < numPriorityChans; i++ { workerPools[i] = make(chan *InputLineWithNameServer, 1) } - return &WorkerPools{WorkerPools: workerPools, GlobalTaskPool: make(chan *InputLineWithNameServer)} + return &WorkChans{PriorityWorkChans: workerPools, GlobalWorkChan: make(chan *InputLineWithNameServer)} } // InputLineWithNameServer is a struct that contains a line of input and the name server to use for the lookup // This name server is a "suggestion", --iterative lookups will ignore it as well as AXFR lookups -// The goal is to attempt to send all queries for a single name server to the same worker pool type InputLineWithNameServer struct { Line string NameServer *zdns.NameServer } -// inputDeMultiplxer is a single goroutine that reads from the input channel and sends the input to the appropriate worker pool channel -// The goal is that a query for a single name server will consistently go to the same worker pool which 1+ threads will read from -// This is
especially useful for TLS/TCP/HTTPS based lookups where repeating the initial handshakes would be wasteful +// inputDeMultiplexer is a single goroutine that reads from the input channel and prioritizes sending work to its respective +// prioritized input channel. If the priority channel is full, it will send the work to the global work channel for an idle thread +// to load balance. The goal is that a worker thread will tend to re-use its existing TCP/TLS/HTTPS connection, saving handshakes. // Work Balancing -// The GlobalTaskPool is used to address work imbalance between worker pools. If a query should go to Pool A but Pool A is busy, it will go to the GlobalTaskPool -// Workers will check the GlobalTaskPool only if their pool is empty. This means they will tend to re-use their connections, but help out other pools if they're idle -func inputDeMultiplexer(nameServers []zdns.NameServer, inChan <-chan string, workerPools *WorkerPools, wg *sync.WaitGroup) error { +// The GlobalWorkChan is used to address work imbalance between worker pools. If a query should go to Priority Channel A but A is busy, it will go to the GlobalWorkChan +// Workers will check the GlobalWorkChan only if their Priority channel is empty.
This means they will tend to re-use their connections, but help out other pools if they're idle +func inputDeMultiplexer(nameServers []zdns.NameServer, inChan <-chan string, workerPools *WorkChans, wg *sync.WaitGroup) error { defer wg.Done() - // defer closing the worker pool chans defer func() { - for _, pool := range workerPools.WorkerPools { + // cleanup work channels + for _, pool := range workerPools.PriorityWorkChans { close(pool) } - close(workerPools.GlobalTaskPool) + close(workerPools.GlobalWorkChan) }() for line := range inChan { nsIndex := rand.Intn(len(nameServers)) randomNS := nameServers[nsIndex] - chanID := nsIndex % len(workerPools.WorkerPools) + chanID := nsIndex % len(workerPools.PriorityWorkChans) work := &InputLineWithNameServer{Line: line, NameServer: &randomNS} - // for each work item, we prefer to send it to the assigned worker pool for the name server. If that pool is busy, we'll send it to the global task pool + // for each work item, we prefer to send it to the assigned worker pool for the name server select { - case workerPools.WorkerPools[chanID] <- work: // prefer to send to the worker pool for the name server + case workerPools.PriorityWorkChans[chanID] <- work: // prefer to send to the worker pool for the name server default: - // worker pool is busy, we'll take first available spot between the global task pool and the worker pool + // worker pool is busy, we'll take first available spot between the global and priority channels select { - case workerPools.GlobalTaskPool <- work: - case workerPools.WorkerPools[chanID] <- work: + case workerPools.GlobalWorkChan <- work: + case workerPools.PriorityWorkChans[chanID] <- work: } } } @@ -553,12 +554,12 @@ func Run(gc CLIConf) { } uniqNameServers = uniqueDomainNSes } - numberOfWorkerPools := len(uniqNameServers) - if gc.Threads < numberOfWorkerPools { + numberOfPriorityChans := len(uniqNameServers) + if gc.Threads < numberOfPriorityChans { // multiple threads can share a channel, but we can't 
have more channels than threads - numberOfWorkerPools = gc.Threads + numberOfPriorityChans = gc.Threads } - workerPools := NewWorkerPools(numberOfWorkerPools) + workerPools := NewWorkerChans(numberOfPriorityChans) // Use handlers to populate the input and output/results channel go func() { @@ -588,10 +589,10 @@ func Run(gc CLIConf) { // create shared cache for all threads to share for i := 0; i < gc.Threads; i++ { i := i - // assign each worker to a worker pool, we'll loop around if we have more workers than pools - channelID := i % len(workerPools.WorkerPools) + // assign each worker to a priority channel, we'll loop around if we have more workers than channels + channelID := i % len(workerPools.PriorityWorkChans) go func(threadID int) { - initWorkerErr := doLookupWorker(&gc, resolverConfig, workerPools.WorkerPools[channelID], workerPools.GlobalTaskPool, outChan, metaChan, &lookupWG) + initWorkerErr := doLookupWorker(&gc, resolverConfig, workerPools.PriorityWorkChans[channelID], workerPools.GlobalWorkChan, outChan, metaChan, &lookupWG) if initWorkerErr != nil { log.Fatalf("could not start lookup worker #%d: %v", i, initWorkerErr) } @@ -655,7 +656,7 @@ func doLookupWorker(gc *CLIConf, rc *zdns.ResolverConfig, preferredWorkChan, glo WorkerLoop: for { - // Check its own resolver tasks first (reusing connections) + // Check its own priority channel first to prioritize re-using TCP/HTTPS/TLS connections select { case task, ok = <-preferredWorkChan: if !ok { @@ -664,7 +665,7 @@ WorkerLoop: } handleWorkerInput(gc, rc, task, resolver, &metadata, output) default: - // No tasks in its own resolver, wait on either + // wait on either Priority/Global channel select { case task, ok = <-preferredWorkChan: if !ok { From 99329eb08d418f9bbb7c4f6ef2e409fd0759cd39 Mon Sep 17 00:00:00 2001 From: phillip-stephens Date: Wed, 11 Sep 2024 15:44:40 -0400 Subject: [PATCH 23/24] removed the shouldRetryIfConnClosed bool, didn't add anything --- src/zdns/lookup.go | 8 ++++---- 1 file changed, 
4 insertions(+), 4 deletions(-) diff --git a/src/zdns/lookup.go b/src/zdns/lookup.go index 1c800a94..0629727a 100644 --- a/src/zdns/lookup.go +++ b/src/zdns/lookup.go @@ -456,10 +456,10 @@ func (r *Resolver) retryingLookup(ctx context.Context, q Question, nameServer *N result, status, err = wireLookupUDP(ctx, connInfo, q, nameServer, r.ednsOptions, recursive, r.dnsSecEnabled, r.checkingDisabledBit) if status == StatusTruncated && connInfo.tcpClient != nil { // result truncated, try again with TCP - result, status, err = wireLookupTCP(ctx, connInfo, q, nameServer, r.ednsOptions, recursive, r.dnsSecEnabled, r.checkingDisabledBit, true) + result, status, err = wireLookupTCP(ctx, connInfo, q, nameServer, r.ednsOptions, recursive, r.dnsSecEnabled, r.checkingDisabledBit) } } else if connInfo.tcpClient != nil { - result, status, err = wireLookupTCP(ctx, connInfo, q, nameServer, r.ednsOptions, recursive, r.dnsSecEnabled, r.checkingDisabledBit, true) + result, status, err = wireLookupTCP(ctx, connInfo, q, nameServer, r.ednsOptions, recursive, r.dnsSecEnabled, r.checkingDisabledBit) } else { return SingleQueryResult{}, StatusError, 0, errors.New("no connection info for nameserver") } @@ -622,7 +622,7 @@ func doDoHLookup(ctx context.Context, httpClient *http.Client, q Question, nameS } // wireLookupTCP performs a DNS lookup on-the-wire over TCP with the given parameters -func wireLookupTCP(ctx context.Context, connInfo *ConnectionInfo, q Question, nameServer *NameServer, ednsOptions []dns.EDNS0, recursive, dnssec, checkingDisabled, retryOnConnClosing bool) (SingleQueryResult, Status, error) { +func wireLookupTCP(ctx context.Context, connInfo *ConnectionInfo, q Question, nameServer *NameServer, ednsOptions []dns.EDNS0, recursive, dnssec, checkingDisabled bool) (SingleQueryResult, Status, error) { res := SingleQueryResult{Answers: []interface{}{}, Authorities: []interface{}{}, Additional: []interface{}{}} res.Resolver = nameServer.String() @@ -648,7 +648,7 @@ func 
wireLookupTCP(ctx context.Context, connInfo *ConnectionInfo, q Question, na return SingleQueryResult{}, StatusError, fmt.Errorf("could not resolve TCP address %s: %v", nameServer.String(), err) } r, _, err = connInfo.tcpClient.ExchangeWithConnToContext(ctx, m, connInfo.tcpConn, addr) - if retryOnConnClosing && err != nil && err.Error() == "EOF" { + if err != nil && err.Error() == "EOF" { // EOF error means the connection was closed, we'll remove the connection (it'll be recreated on the next iteration) // and try again err = connInfo.tcpConn.Conn.Close() From 7328f5f5c394f24a99f181decee9b3f4d2df4028 Mon Sep 17 00:00:00 2001 From: phillip-stephens Date: Wed, 11 Sep 2024 15:52:51 -0400 Subject: [PATCH 24/24] cleanup --- src/zdns/lookup.go | 1 - src/zdns/resolver.go | 6 +++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/zdns/lookup.go b/src/zdns/lookup.go index 0629727a..0401aeed 100644 --- a/src/zdns/lookup.go +++ b/src/zdns/lookup.go @@ -700,7 +700,6 @@ func wireLookupUDP(ctx context.Context, connInfo *ConnectionInfo, q Question, na } else { r, _, err = connInfo.udpClient.ExchangeContext(ctx, m, nameServer.String()) } - // if record comes back truncated, but we have a TCP connection, try again with that if r != nil && (r.Truncated || r.Rcode == dns.RcodeBadTrunc) { return res, StatusTruncated, err } diff --git a/src/zdns/resolver.go b/src/zdns/resolver.go index 129e1449..cf270b7e 100644 --- a/src/zdns/resolver.go +++ b/src/zdns/resolver.go @@ -113,11 +113,11 @@ func (rc *ResolverConfig) Validate() error { // External Nameservers if rc.IPVersionMode != IPv6Only && len(rc.ExternalNameServersV4) == 0 { // If IPv4 is supported, we require at least one IPv4 external nameserver - return errors.New("must have at least one external IPv4 name server if IPv4 mode is enabled. Use IPv6 only if you don't have IPv4 nameservers") + return errors.New("must have at least one external IPv4 name server if IPv4 mode is enabled. 
Use IPv6-only if you don't have IPv4 nameservers") } if rc.IPVersionMode != IPv4Only && len(rc.ExternalNameServersV6) == 0 { // If IPv6 is supported, we require at least one IPv6 external nameserver - return errors.New("must have at least one external IPv6 name server if IPv6 mode is enabled. Use IPv4 only if you don't have IPv6 nameservers") + return errors.New("must have at least one external IPv6 name server if IPv6 mode is enabled. Use IPv4-only if you don't have IPv6 nameservers") } // Validate all nameservers have ports and are valid IPs @@ -234,7 +234,7 @@ type ConnectionInfo struct { udpClient *dns.Client tcpClient *dns.Client udpConn *dns.Conn // for socket re-use with UDP - tcpConn *dns.Conn // for socket re-use with TCP, if RemoteAddr doesn't change, we don't re-handshake + tcpConn *dns.Conn // for socket re-use with TCP httpsClient *http.Client // for DoH tlsConn *dns.Conn // for DoT tlsHandshake *tls.ServerHandshake // for DoT, used to print TLS handshake to user