Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DoH/DoT/TCP-based lookups and connection re-use #439

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
8fe1437
working input demultiplexor with tls
phillip-stephens Sep 9, 2024
0957e6f
handled tcp conns
phillip-stephens Sep 10, 2024
d425d40
handle HTTPS de-multiplexing
phillip-stephens Sep 10, 2024
69c5106
lint
phillip-stephens Sep 10, 2024
c45ebc8
improved error msg if user only supplies IPv4 addresses and we fail c…
phillip-stephens Sep 10, 2024
67e8149
added AXFR edge case handling
phillip-stephens Sep 10, 2024
fade130
added comments
phillip-stephens Sep 10, 2024
0372d8d
if TCP connection is closed, re-open it
phillip-stephens Sep 10, 2024
d114668
don't loop in retrying tcp connection
phillip-stephens Sep 10, 2024
afe5781
spelling
phillip-stephens Sep 10, 2024
e3aa7c8
close TCP conns in Close()
phillip-stephens Sep 10, 2024
137ebbb
Merge branch 'phillip/336-dns-over-https' into phillip/336-doh-and-do…
phillip-stephens Sep 10, 2024
2cb7877
trying multiple de-multiplexors
phillip-stephens Sep 10, 2024
58b790a
Revert "trying multiple de-multiplexors"
phillip-stephens Sep 10, 2024
efb3d5d
TEST - check how long non-network activity takes
phillip-stephens Sep 10, 2024
fbee6de
TEST - :(
phillip-stephens Sep 10, 2024
e7e7005
removed testing line
phillip-stephens Sep 10, 2024
0249fa1
trying giving the pool channels a capacity
phillip-stephens Sep 10, 2024
342fd09
implement work-balancing scheme
phillip-stephens Sep 11, 2024
b439c0b
added small wait before going to global queue
phillip-stephens Sep 11, 2024
f83cf19
fix errors if destination closes the TCP connection
phillip-stephens Sep 11, 2024
6086ce3
lint
phillip-stephens Sep 11, 2024
172d1cc
refactor - coalesce language around worker channels
phillip-stephens Sep 11, 2024
99329eb
removed the shouldRetryIfConnClosed bool, didn't add anything
phillip-stephens Sep 11, 2024
7328f5f
cleanup
phillip-stephens Sep 11, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
322 changes: 235 additions & 87 deletions src/cli/worker_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,64 @@ func populateLocalAddresses(gc *CLIConf, config *zdns.ResolverConfig) (*zdns.Res
return config, nil
}

// WorkChans are a collection of channels that workers can read from
// Worker threads are assigned a PriorityWorkChan, which is a channel where all queries are directed at a single nameserver
// If a thread is idle, it will read from the GlobalWorkChan, helping to relieve work imbalance between worker nameservers
type WorkChans struct {
	// PriorityWorkChans holds one channel per (capped) name server; each worker is pinned to one
	// of these so queries for the same name server tend to land on the same worker.
	PriorityWorkChans []chan *InputLineWithNameServer
	// GlobalWorkChan is the shared overflow channel, read by any worker whose priority channel is empty.
	GlobalWorkChan chan *InputLineWithNameServer
}

// NewWorkerChans constructs a WorkChans with numPriorityChans per-nameserver priority
// channels (each buffered with capacity 1) and a single unbuffered global work channel.
func NewWorkerChans(numPriorityChans int) *WorkChans {
	priorityChans := make([]chan *InputLineWithNameServer, numPriorityChans)
	for i := range priorityChans {
		priorityChans[i] = make(chan *InputLineWithNameServer, 1)
	}
	return &WorkChans{
		PriorityWorkChans: priorityChans,
		GlobalWorkChan:    make(chan *InputLineWithNameServer),
	}
}

// InputLineWithNameServer is a struct that contains a line of input and the name server to use for the lookup
// This name server is a "suggestion", --iterative lookups will ignore it as well as AXFR lookups
type InputLineWithNameServer struct {
	// Line is the raw input line; the worker parses it for the domain and any user-supplied name server.
	Line string
	// NameServer is the de-multiplexer's suggested server; it is overridden if the input line itself names one.
	NameServer *zdns.NameServer
}

// inputDeMultiplexer is a single goroutine that reads raw input lines from inChan and routes each one,
// paired with a randomly-chosen name server, to its respective prioritized worker channel. The goal is that
// a worker thread will tend to re-use its existing TCP/TLS/HTTPS connection, saving handshakes.
//
// Work Balancing: if the preferred priority channel is full, the work item is offered to whichever of the
// GlobalWorkChan or the priority channel becomes ready first. Workers check the GlobalWorkChan only when
// their own priority channel is empty, so they favor connection re-use but help out other pools when idle.
// All worker channels are closed once inChan is exhausted, signalling workers to exit.
func inputDeMultiplexer(nameServers []zdns.NameServer, inChan <-chan string, workerPools *WorkChans, wg *sync.WaitGroup) error {
	defer wg.Done()
	// on exit, close every worker channel so downstream read loops terminate
	defer func() {
		for _, ch := range workerPools.PriorityWorkChans {
			close(ch)
		}
		close(workerPools.GlobalWorkChan)
	}()
	numChans := len(workerPools.PriorityWorkChans)
	for line := range inChan {
		idx := rand.Intn(len(nameServers))
		// copy the chosen name server so each work item carries its own pointer
		ns := nameServers[idx]
		work := &InputLineWithNameServer{Line: line, NameServer: &ns}
		slot := idx % numChans
		// non-blocking attempt at the preferred (per-nameserver) channel first
		select {
		case workerPools.PriorityWorkChans[slot] <- work:
			continue
		default:
		}
		// preferred channel is busy: block until either the global channel or the
		// preferred channel can accept the work, whichever frees up first
		select {
		case workerPools.GlobalWorkChan <- work:
		case workerPools.PriorityWorkChans[slot] <- work:
		}
	}
	return nil
}

func Run(gc CLIConf) {
gc = *populateCLIConfig(&gc)
resolverConfig := populateResolverConfig(&gc)
Expand Down Expand Up @@ -468,20 +526,61 @@ func Run(gc CLIConf) {
log.Fatal("Output handler is nil")
}

nameServers := util.Concat(resolverConfig.ExternalNameServersV4, resolverConfig.ExternalNameServersV6, resolverConfig.RootNameServersV4, resolverConfig.RootNameServersV6)
// de-dupe
nsLookupMap := make(map[uint32]struct{})
uniqNameServers := make([]zdns.NameServer, 0, len(nameServers))
var hash uint32
for _, ns := range nameServers {
hash, err = ns.Hash()
if err != nil {
log.Fatalf("could not hash name server %s: %v", ns.String(), err)
}
if _, ok := nsLookupMap[hash]; !ok {
nsLookupMap[hash] = struct{}{}
uniqNameServers = append(uniqNameServers, ns)
}
}
// DoH lookups only depend on domain name, remove nameservers with duplicate domains
// We don't want the deMultiplexer to send the same domain (with different IPs) to different worker pools
if gc.DNSOverHTTPS {
nsDomainLookupMap := make(map[string]struct{})
uniqueDomainNSes := make([]zdns.NameServer, 0, len(uniqNameServers))
for _, ns := range uniqNameServers {
if _, ok := nsDomainLookupMap[ns.DomainName]; !ok {
nsDomainLookupMap[ns.DomainName] = struct{}{}
uniqueDomainNSes = append(uniqueDomainNSes, ns)
}
}
uniqNameServers = uniqueDomainNSes
}
numberOfPriorityChans := len(uniqNameServers)
if gc.Threads < numberOfPriorityChans {
// multiple threads can share a channel, but we can't have more channels than threads
numberOfPriorityChans = gc.Threads
}
workerPools := NewWorkerChans(numberOfPriorityChans)

// Use handlers to populate the input and output/results channel
go func() {
inErr := inHandler.FeedChannel(inChan, &routineWG)
if inErr != nil {
log.Fatal(fmt.Sprintf("could not feed input channel: %v", inErr))
}
}()
go func() {
plexErr := inputDeMultiplexer(uniqNameServers, inChan, workerPools, &routineWG)
if plexErr != nil {
log.Fatal(fmt.Sprintf("could not de-multiplex input channel: %v", plexErr))
}
}()
go func() {
outErr := outHandler.WriteResults(outChan, &routineWG)
if outErr != nil {
log.Fatal(fmt.Sprintf("could not write output results from output channel: %v", outErr))
}
}()
routineWG.Add(2)
routineWG.Add(3)

// create pool of worker goroutines
var lookupWG sync.WaitGroup
Expand All @@ -490,8 +589,10 @@ func Run(gc CLIConf) {
// create shared cache for all threads to share
for i := 0; i < gc.Threads; i++ {
i := i
// assign each worker to a priority channel, we'll loop around if we have more workers than channels
channelID := i % len(workerPools.PriorityWorkChans)
go func(threadID int) {
initWorkerErr := doLookupWorker(&gc, resolverConfig, inChan, outChan, metaChan, &lookupWG)
initWorkerErr := doLookupWorker(&gc, resolverConfig, workerPools.PriorityWorkChans[channelID], workerPools.GlobalWorkChan, outChan, metaChan, &lookupWG)
if initWorkerErr != nil {
log.Fatalf("could not start lookup worker #%d: %v", i, initWorkerErr)
}
Expand Down Expand Up @@ -542,112 +643,159 @@ func Run(gc CLIConf) {
}

// doLookupWorker is a single worker thread that processes lookups from the input channel. It calls wg.Done when it is finished.
func doLookupWorker(gc *CLIConf, rc *zdns.ResolverConfig, input <-chan string, output chan<- string, metaChan chan<- routineMetadata, wg *sync.WaitGroup) error {
func doLookupWorker(gc *CLIConf, rc *zdns.ResolverConfig, preferredWorkChan, globalWorkChan <-chan *InputLineWithNameServer, output chan<- string, metaChan chan<- routineMetadata, wg *sync.WaitGroup) error {
defer wg.Done()
resolver, err := zdns.InitResolver(rc)
if err != nil {
return fmt.Errorf("could not init resolver: %w", err)
}
var metadata routineMetadata
metadata.Status = make(map[zdns.Status]int)
for line := range input {
// we'll process each module sequentially, parallelism is per-domain
res := zdns.Result{Results: make(map[string]zdns.SingleModuleResult, len(gc.ActiveModules))}
// get the fields that won't change for each lookup module
rawName := ""
var nameServer *zdns.NameServer
var nameServers []zdns.NameServer
nameServerString := ""
var rank int
var entryMetadata string
if gc.AlexaFormat {
rawName, rank = parseAlexa(line)
res.AlexaRank = rank
} else if gc.MetadataFormat {
rawName, entryMetadata = parseMetadataInputLine(line)
res.Metadata = entryMetadata
} else if gc.NameServerMode {
nameServers, err = convertNameServerStringToNameServer(line, rc.IPVersionMode, rc.DNSOverTLS, rc.DNSOverHTTPS)
if err != nil {
log.Fatal("unable to parse name server: ", line)
var task *InputLineWithNameServer
var ok bool

WorkerLoop:
for {
// Check its own priority channel first to prioritize re-using TCP/HTTPS/TLS connections
select {
case task, ok = <-preferredWorkChan:
if !ok {
// inputDeMultiplexer has closed the channel, we're done
break WorkerLoop
}
if len(nameServers) == 0 {
log.Fatal("no name servers found in line: ", line)
}
// if user provides a domain name for the name server (one.one.one.one) we'll pick one of the IPs at random
nameServer = &nameServers[rand.Intn(len(nameServers))]
} else {
rawName, nameServerString = parseNormalInputLine(line)
if len(nameServerString) != 0 {
nameServers, err = convertNameServerStringToNameServer(nameServerString, rc.IPVersionMode, rc.DNSOverTLS, rc.DNSOverHTTPS)
if err != nil {
log.Fatal("unable to parse name server: ", line)
handleWorkerInput(gc, rc, task, resolver, &metadata, output)
default:
// wait on either Priority/Global channel
select {
case task, ok = <-preferredWorkChan:
if !ok {
break WorkerLoop
}
if len(nameServers) == 0 {
log.Fatal("no name servers found in line: ", line)
handleWorkerInput(gc, rc, task, resolver, &metadata, output)
case task, ok = <-globalWorkChan:
if !ok {
break WorkerLoop
}
// if user provides a domain name for the name server (one.one.one.one) we'll pick one of the IPs at random
nameServer = &nameServers[rand.Intn(len(nameServers))]
handleWorkerInput(gc, rc, task, resolver, &metadata, output)
}
}
res.Name = rawName
// handle per-module lookups
for moduleName, module := range gc.ActiveModules {
var innerRes interface{}
var trace zdns.Trace
var status zdns.Status
var err error
var changed bool
var lookupName string
lookupName, changed = makeName(rawName, gc.NamePrefix, gc.NameOverride)
if changed {
res.AlteredName = lookupName
}
// close the resolver, freeing up resources
resolver.Close()
metaChan <- metadata
return nil
}

func handleWorkerInput(gc *CLIConf, rc *zdns.ResolverConfig, line *InputLineWithNameServer, resolver *zdns.Resolver, metadata *routineMetadata, output chan<- string) {
// we'll process each module sequentially, parallelism is per-domain
res, rawName, nameServer := parseInputLine(gc, rc, line)
res.Name = rawName
// handle per-module lookups
for moduleName, module := range gc.ActiveModules {
if moduleName == "AXFR" {
// special case, AXFR has its own nameserver handling. We'll only take nameservers if the user provides it
// not the "suggestion" from the de-multiplexor
if nameServer.String() == line.NameServer.String() {
// this name server is the suggested one from the de-multiplexor, we'll remove it
nameServer = nil
}
res.Class = dns.Class(gc.Class).String()
}
var innerRes interface{}
var trace zdns.Trace
var status zdns.Status
var err error
var changed bool
var lookupName string
lookupName, changed = makeName(rawName, gc.NamePrefix, gc.NameOverride)
if changed {
res.AlteredName = lookupName
}
res.Class = dns.Class(gc.Class).String()

startTime := time.Now()
innerRes, trace, status, err = module.Lookup(resolver, lookupName, nameServer)
startTime := time.Now()
innerRes, trace, status, err = module.Lookup(resolver, lookupName, nameServer)

lookupRes := zdns.SingleModuleResult{
Timestamp: time.Now().Format(gc.TimeFormat),
Duration: time.Since(startTime).Seconds(),
}
if status != zdns.StatusNoOutput {
lookupRes.Status = string(status)
lookupRes.Data = innerRes
lookupRes.Trace = trace
if err != nil {
lookupRes.Error = err.Error()
}
res.Results[moduleName] = lookupRes
}
metadata.Status[status]++
metadata.Lookups++
}
if len(res.Results) > 0 {
v, _ := version.NewVersion("0.0.0")
o := &sheriff.Options{
Groups: gc.OutputGroups,
ApiVersion: v,
IncludeEmptyTag: true,
}
data, err := sheriff.Marshal(o, res)
lookupRes := zdns.SingleModuleResult{
Timestamp: time.Now().Format(gc.TimeFormat),
Duration: time.Since(startTime).Seconds(),
}
if status != zdns.StatusNoOutput {
lookupRes.Status = string(status)
lookupRes.Data = innerRes
lookupRes.Trace = trace
if err != nil {
log.Fatalf("unable to marshal result to JSON: %v", err)
lookupRes.Error = err.Error()
}
cleansedData := replaceIntSliceInterface(data)
jsonRes, err := json.Marshal(cleansedData)
res.Results[moduleName] = lookupRes
}
metadata.Status[status]++
metadata.Lookups++
}
if len(res.Results) > 0 {
v, _ := version.NewVersion("0.0.0")
o := &sheriff.Options{
Groups: gc.OutputGroups,
ApiVersion: v,
IncludeEmptyTag: true,
}
data, err := sheriff.Marshal(o, res)
if err != nil {
log.Fatalf("unable to marshal result to JSON: %v", err)
}
cleansedData := replaceIntSliceInterface(data)
jsonRes, err := json.Marshal(cleansedData)
if err != nil {
log.Fatalf("unable to marshal JSON result: %v", err)
}
output <- string(jsonRes)
}
metadata.Names++
}

func parseInputLine(gc *CLIConf, rc *zdns.ResolverConfig, line *InputLineWithNameServer) (*zdns.Result, string, *zdns.NameServer) {
res := zdns.Result{Results: make(map[string]zdns.SingleModuleResult, len(gc.ActiveModules))}
// get the fields that won't change for each lookup module
rawName := ""
// this is the name server "suggested" by the de-multiplexor. The goal is that if
// 1) the user doesn't provide a nameserver
// 2) we're in external lookup mode
// then we'll use the suggestion. This is to avoid the overhead of re-handshaking for each lookup
// it's overwritten if the user provides a nameserver as part of the input line below
nameServer := line.NameServer
nameServerString := ""
var rank int
var entryMetadata string
if gc.AlexaFormat {
rawName, rank = parseAlexa(line.Line)
res.AlexaRank = rank
} else if gc.MetadataFormat {
rawName, entryMetadata = parseMetadataInputLine(line.Line)
res.Metadata = entryMetadata
} else if gc.NameServerMode {
nameServers, err := convertNameServerStringToNameServer(line.Line, rc.IPVersionMode, rc.DNSOverTLS, rc.DNSOverHTTPS)
if err != nil {
log.Fatal("unable to parse name server: ", line.Line)
}
if len(nameServers) == 0 {
log.Fatal("no name servers found in line: ", line.Line)
}
// if user provides a domain name for the name server (one.one.one.one) we'll pick one of the IPs at random
nameServer = &nameServers[rand.Intn(len(nameServers))]
} else {
rawName, nameServerString = parseNormalInputLine(line.Line)
if len(nameServerString) != 0 {
nameServers, err := convertNameServerStringToNameServer(nameServerString, rc.IPVersionMode, rc.DNSOverTLS, rc.DNSOverHTTPS)
if err != nil {
log.Fatalf("unable to marshal JSON result: %v", err)
log.Fatal("unable to parse name server: ", line.Line)
}
if len(nameServers) == 0 {
log.Fatal("no name servers found in line: ", line.Line)
}
output <- string(jsonRes)
// if user provides a domain name for the name server (one.one.one.one) we'll pick one of the IPs at random
nameServer = &nameServers[rand.Intn(len(nameServers))]
}
metadata.Names++
}
// close the resolver, freeing up resources
resolver.Close()
metaChan <- metadata
return nil
return &res, rawName, nameServer
}

func parseAlexa(line string) (string, int) {
Expand Down
4 changes: 2 additions & 2 deletions src/modules/spf/spf.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ func (spfMod *SpfLookupModule) CLIInit(gc *cli.CLIConf, rc *zdns.ResolverConfig)
return spfMod.BasicLookupModule.CLIInit(gc, rc)
}

func (spfMod *SpfLookupModule) Lookup(r *zdns.Resolver, name string, resolver *zdns.NameServer) (interface{}, zdns.Trace, zdns.Status, error) {
innerRes, trace, status, err := spfMod.BasicLookupModule.Lookup(r, name, resolver)
func (spfMod *SpfLookupModule) Lookup(r *zdns.Resolver, name string, nameServer *zdns.NameServer) (interface{}, zdns.Trace, zdns.Status, error) {
innerRes, trace, status, err := spfMod.BasicLookupModule.Lookup(r, name, nameServer)
castedInnerRes, ok := innerRes.(*zdns.SingleQueryResult)
if !ok {
return nil, trace, status, errors.New("lookup didn't return a single query result type")
Expand Down
Loading