diff --git a/makefile b/makefile index 0a1f60c4..395d412f 100644 --- a/makefile +++ b/makefile @@ -12,7 +12,7 @@ install: zdns test: zdns go test -v ./... pip3 install -r testing/requirements.txt - pytest -n auto testing/integration_tests.py + pytest -n 4 testing/integration_tests.py integration-tests: zdns pip3 install -r testing/requirements.txt diff --git a/src/cli/worker_manager.go b/src/cli/worker_manager.go index d39009e8..04312715 100644 --- a/src/cli/worker_manager.go +++ b/src/cli/worker_manager.go @@ -38,12 +38,6 @@ import ( "github.com/zmap/zdns/src/zdns" ) -const ( - googleDNSResolverAddr = "8.8.8.8:53" - googleDNSResolverAddrV6 = "[2001:4860:4860::8888]:53" - loopbackIPv4Addr = "127.0.0.1" -) - type routineMetadata struct { Names int // number of domain names processed Lookups int // number of lookups performed @@ -236,14 +230,12 @@ func populateResolverConfig(gc *CLIConf) *zdns.ResolverConfig { config.RootNameServersV4 = []zdns.NameServer{} } noV4NameServers := len(config.ExternalNameServersV4) == 0 && len(config.RootNameServersV4) == 0 - if config.IPVersionMode != zdns.IPv6Only && noV4NameServers { - log.Info("no IPv4 nameservers found. Switching to --6 only") - config.IPVersionMode = zdns.IPv6Only + if gc.IPv4TransportOnly && noV4NameServers { + log.Fatal("cannot use --4 since no IPv4 nameservers found, ensure you have IPv4 connectivity and provide --name-servers") } noV6NameServers := len(config.ExternalNameServersV6) == 0 && len(config.RootNameServersV6) == 0 - if config.IPVersionMode != zdns.IPv4Only && noV6NameServers { - log.Info("no IPv6 nameservers found. Switching to --4 only") - config.IPVersionMode = zdns.IPv4Only + if gc.IPv6TransportOnly && noV6NameServers { + log.Fatal("cannot use --6 since no IPv6 nameservers found, ensure you have IPv6 connectivity and provide --name-servers") } config, err = populateLocalAddresses(gc, config) @@ -255,7 +247,9 @@ func populateResolverConfig(gc *CLIConf) *zdns.ResolverConfig { // populateIPTransportMode populates the IPTransportMode field of the ResolverConfig // If user sets --4 (IPv4 Only) or --6 (IPv6 Only), we'll set the IPVersionMode to IPv4Only or IPv6Only, respectively. -// Otherwise, we need to determine the IPVersionMode based on either the OS' default resolver(s) or the user's provided name servers. +// If user does not set --4 or --6, we'll determine the IPVersionMode based on: +// 1. the provided name-servers (if any) +// 2. the OS' default resolvers (if no name-servers provided) func populateIPTransportMode(gc *CLIConf, config *zdns.ResolverConfig) (*zdns.ResolverConfig, error) { if gc.IPv4TransportOnly && gc.IPv6TransportOnly { return nil, errors.New("only one of --4 and --6 allowed") @@ -387,10 +381,8 @@ func populateNameServers(gc *CLIConf, config *zdns.ResolverConfig) (*zdns.Resolv func populateLocalAddresses(gc *CLIConf, config *zdns.ResolverConfig) (*zdns.ResolverConfig, error) { // Local Addresses are populated in this order: // 1. If user provided local addresses, use those - // 2. If the config's nameservers are loopback, use the local loopback address - // 3. Otherwise, try to connect to Google's recursive resolver and take the IP address we use for the connection + // 2. If user does not provide local addresses, one will be used on-demand by Resolver. See resolver.go:getConnectionInfo for more info - // IPv4 local addresses are required for IPv4 lookups, same for IPv6 if len(gc.LocalAddrs) != 0 { // if user provided a local address(es), that takes precedent config.LocalAddrsV4, config.LocalAddrsV6 = []net.IP{}, []net.IP{} @@ -406,50 +398,6 @@ func populateLocalAddresses(gc *CLIConf, config *zdns.ResolverConfig) (*zdns.Res return nil, fmt.Errorf("invalid local address: %s", addr.String()) } } - return config, nil - } - // if the nameservers are loopback, use the loopback address - allNameServers := util.Concat(config.ExternalNameServersV4, config.ExternalNameServersV6, config.RootNameServersV4, config.RootNameServersV6) - if len(allNameServers) == 0 { - // this shouldn't happen - return nil, errors.New("name servers must be set before populating local addresses") - } - anyNameServersLoopack := false - for _, ns := range allNameServers { - if ns.IP.IsLoopback() { - anyNameServersLoopack = true - break - } - } - - if anyNameServersLoopack { - // set local address so name servers are reachable - config.LocalAddrsV4 = []net.IP{net.ParseIP(loopbackIPv4Addr)} - // loopback nameservers not supported for IPv6, we'll let Resolver validation take care of this - } else { - // localAddr not set, so we need to find the default IP address - if config.IPVersionMode != zdns.IPv6Only { - conn, err := net.Dial("udp", googleDNSResolverAddr) - if err != nil { - return nil, fmt.Errorf("unable to find default IP address to open socket: %w", err) - } - config.LocalAddrsV4 = []net.IP{conn.LocalAddr().(*net.UDPAddr).IP} - // cleanup socket - if err = conn.Close(); err != nil { - log.Error("unable to close test connection to Google public DNS: ", err) - } - } - if config.IPVersionMode != zdns.IPv4Only { - conn, err := net.Dial("udp", googleDNSResolverAddrV6) - if err != nil { - return nil, fmt.Errorf("unable to find default IP address to open socket: %w", err) - } - config.LocalAddrsV6 = []net.IP{conn.LocalAddr().(*net.UDPAddr).IP} - // cleanup socket - if err = conn.Close(); err != nil { - log.Error("unable to close test connection to Google public IPv6 DNS: ", err) - } - } } return config, nil } diff --git a/src/modules/nslookup/ns_lookup.go b/src/modules/nslookup/ns_lookup.go index d0e58cd3..f25ca877 100644 --- a/src/modules/nslookup/ns_lookup.go +++ b/src/modules/nslookup/ns_lookup.go @@ -38,7 +38,7 @@ type NSLookupModule struct { // CLIInit initializes the NSLookupModule with the given parameters, used to call NSLookup from the command line func (nsMod *NSLookupModule) CLIInit(gc *cli.CLIConf, resolverConf *zdns.ResolverConfig) error { if !nsMod.IPv4Lookup && !nsMod.IPv6Lookup { - log.Debug("NSModule: No IP version specified, defaulting to IPv4") + log.Debug("NSModule: neither --ipv4-lookup nor --ipv6-lookup specified, will only request A records for each NS server") nsMod.IPv4Lookup = true } err := nsMod.BasicLookupModule.CLIInit(gc, resolverConf) diff --git a/src/zdns/lookup.go b/src/zdns/lookup.go index 08db3b0f..e91944d0 100644 --- a/src/zdns/lookup.go +++ b/src/zdns/lookup.go @@ -427,22 +427,14 @@ func (r *Resolver) retryingLookup(ctx context.Context, q Question, nameServer *N if nameServer == nil { return SingleQueryResult{}, StatusIllegalInput, 0, errors.New("no nameserver specified") } - var connInfo *ConnectionInfo - if nameServer.IP.To4() != nil { - connInfo = r.connInfoIPv4 - } else if util.IsIPv6(&nameServer.IP) { - connInfo = r.connInfoIPv6 - } else { - return SingleQueryResult{}, StatusError, 0, fmt.Errorf("could not determine IP version of nameserver: %s", nameServer) + connInfo, err := r.getConnectionInfo(nameServer) + if err != nil { + return SingleQueryResult{}, StatusError, 0, fmt.Errorf("could not get a connection info to query nameserver %s: %v", nameServer, err) } // check that our connection info is valid if connInfo == nil { return SingleQueryResult{}, StatusError, 0, fmt.Errorf("no connection info for nameserver: %s", nameServer) } - // check loopback consistency - if nameServer.IP.IsLoopback() != connInfo.localAddr.IsLoopback() { - return SingleQueryResult{}, StatusIllegalInput, 0, fmt.Errorf("nameserver %s must be reachable from the local address %s, ie. both must be loopback or not loopback", nameServer, connInfo.localAddr.String()) - } r.verboseLog(1, "****WIRE LOOKUP*** ", dns.TypeToString[q.Type], " ", q.Name, " ", nameServer) for i := 0; i <= r.retries; i++ { // check context before going into wireLookup diff --git a/src/zdns/lookup_test.go b/src/zdns/lookup_test.go index 79f45269..2dbaf2b8 100644 --- a/src/zdns/lookup_test.go +++ b/src/zdns/lookup_test.go @@ -1980,13 +1980,6 @@ func TestInvalidInputsLookup(t *testing.T) { _, _, _, err := resolver.ExternalLookup(&q, &NameServer{IP: net.ParseIP("127.0.0.53")}) assert.Nil(t, err) }) - t.Run("using a loopback local addr with non-loopback nameserver", func(t *testing.T) { - result, trace, status, err := resolver.ExternalLookup(&q, &NameServer{IP: net.ParseIP("1.1.1.1"), Port: 53}) - assert.Nil(t, result) - assert.Nil(t, trace) - assert.Equal(t, StatusIllegalInput, status) - assert.NotNil(t, err) - }) t.Run("invalid nameserver address", func(t *testing.T) { result, trace, status, err := resolver.ExternalLookup(&q, &NameServer{IP: net.ParseIP("987.987.987.987"), Port: 53}) assert.Nil(t, result) diff --git a/src/zdns/resolver.go b/src/zdns/resolver.go index 5f534427..8b91f4f2 100644 --- a/src/zdns/resolver.go +++ b/src/zdns/resolver.go @@ -48,6 +48,8 @@ const ( defaultIterationIPPreference = PreferIPv4 DefaultNameServerConfigFile = "/etc/resolv.conf" defaultLookupAllNameServers = false + DefaultLoopbackIPv4Addr = "127.0.0.1" + DefaultLoopbackIPv6Addr = "::1" ) // ResolverConfig is a struct that holds all the configuration options for a Resolver. It is used to create a new Resolver. @@ -131,14 +133,6 @@ func (rc *ResolverConfig) Validate() error { } } - // Local Addresses - if rc.IPVersionMode != IPv6Only && len(rc.LocalAddrsV4) == 0 { - return errors.New("must have a local IPv4 address to send traffic from") - } - if rc.IPVersionMode != IPv4Only && len(rc.LocalAddrsV6) == 0 { - return errors.New("must have a local IPv6 address to send traffic from") - } - // Validate all local addresses are valid IPs for _, ip := range util.Concat(rc.LocalAddrsV4, rc.LocalAddrsV6) { if ip == nil { @@ -176,38 +170,6 @@ func (rc *ResolverConfig) Validate() error { return fmt.Errorf("link-local IPv6 external/root nameservers are not supported: %v", ns.IP) } } - - if err := rc.validateLoopbackConsistency(); err != nil { - return errors.Wrap(err, "could not validate loopback consistency") - } - - return nil -} - -// validateLoopbackConsistency checks that the following is true -// - either all nameservers AND all local addresses are loopback, or none are -func (rc *ResolverConfig) validateLoopbackConsistency() error { - allLocalAddrs := util.Concat(rc.LocalAddrsV4, rc.LocalAddrsV6) - allExternalNameServers := util.Concat(rc.ExternalNameServersV4, rc.ExternalNameServersV6) - allRootNameServers := util.Concat(rc.RootNameServersV4, rc.RootNameServersV6) - allIPsLength := len(allLocalAddrs) + len(allExternalNameServers) + len(allRootNameServers) - allIPs := make([]net.IP, 0, allIPsLength) - allIPs = append(allIPs, allLocalAddrs...) - for _, ns := range util.Concat(allExternalNameServers, allRootNameServers) { - allIPs = append(allIPs, ns.IP) - } - allIPsLoopback := true - noneIPsLoopback := true - for _, ip := range allIPs { - if ip.IsLoopback() { - noneIPsLoopback = false - } else { - allIPsLoopback = false - } - } - if allIPsLoopback == noneIPsLoopback { - return fmt.Errorf("cannot mix loopback and non-loopback local addresses (%v) and name servers (%v)", allLocalAddrs, util.Concat(allExternalNameServers, allRootNameServers)) - } return nil } @@ -270,10 +232,13 @@ type Resolver struct { cache *Cache lookupClient Lookuper // either a functional or mock Lookuper client for testing - blacklist *blacklist.SafeBlacklist - - connInfoIPv4 *ConnectionInfo - connInfoIPv6 *ConnectionInfo + blacklist *blacklist.SafeBlacklist + userPreferredIPv4LocalAddrs []net.IP // user-supplied local IPv4 addresses, we'll prefer to use these + userPreferredIPv6LocalAddrs []net.IP // user-supplied local IPv6 addresses, we'll prefer to use these + connInfoIPv4Internet *ConnectionInfo // used for IPv4 lookups to Internet-facing nameservers + connInfoIPv6Internet *ConnectionInfo // used for IPv6 lookups to Internet-facing nameservers + connInfoIPv4Loopback *ConnectionInfo // used for IPv4 lookups to loopback nameservers + connInfoIPv6Loopback *ConnectionInfo // used for IPv6 lookups to loopback nameservers retries int logLevel log.Level @@ -338,22 +303,9 @@ func InitResolver(config *ResolverConfig) (*Resolver, error) { checkingDisabledBit: config.CheckingDisabledBit, } log.SetLevel(r.logLevel) - if config.IPVersionMode != IPv6Only { - // create connection info for IPv4 - connInfo, err := getConnectionInfo(config.LocalAddrsV4, config.TransportMode, config.Timeout, config.ShouldRecycleSockets) - if err != nil { - return nil, fmt.Errorf("could not create connection info for IPv4: %w", err) - } - r.connInfoIPv4 = connInfo - } - if config.IPVersionMode != IPv4Only { - // create connection info for IPv6 - connInfo, err := getConnectionInfo(config.LocalAddrsV6, config.TransportMode, config.Timeout, config.ShouldRecycleSockets) - if err != nil { - return nil, fmt.Errorf("could not create connection info for IPv6: %w", err) - } - r.connInfoIPv6 = connInfo - } + // Deep copy local address so Resolver is independent of the config + r.userPreferredIPv4LocalAddrs = DeepCopyIPs(config.LocalAddrsV4) + r.userPreferredIPv6LocalAddrs = DeepCopyIPs(config.LocalAddrsV6) // need to deep-copy here so we're not reliant on the state of the resolver config post-resolver creation r.externalNameServers = make([]NameServer, 0, len(config.ExternalNameServersV4)+len(config.ExternalNameServersV6)) if config.IPVersionMode == IPv4Only || config.IPVersionMode == IPv4OrIPv6 { @@ -394,11 +346,77 @@ func InitResolver(config *ResolverConfig) (*Resolver, error) { return r, nil } -func getConnectionInfo(localAddr []net.IP, transportMode transportMode, timeout time.Duration, shouldRecycleSockets bool) (*ConnectionInfo, error) { +// getConnectionInfo uses the name server to determine if a loopback vs. non-loopback or IPv4/v6 connection should be used +// If the Resolver does not have a connection info for the name server, it will create one. +// ConnectionInfo objects are created on an as-needed basis +func (r *Resolver) getConnectionInfo(nameServer *NameServer) (*ConnectionInfo, error) { + // what local addresses should we use? + isNSIPv6 := util.IsIPv6(&nameServer.IP) + isLoopback := nameServer.IP.IsLoopback() + // check if we have a pre-existing conn info + if isNSIPv6 && isLoopback && r.connInfoIPv6Loopback != nil { + return r.connInfoIPv6Loopback, nil + } else if isNSIPv6 && !isLoopback && r.connInfoIPv6Internet != nil { + return r.connInfoIPv6Internet, nil + } else if !isNSIPv6 && isLoopback && r.connInfoIPv4Loopback != nil { + return r.connInfoIPv4Loopback, nil + } else if !isNSIPv6 && !isLoopback && r.connInfoIPv4Internet != nil { + // must be IPv4 non-loopback + return r.connInfoIPv4Internet, nil + } + + // no existing ConnInfo, create a new one + // r.localAddrs contain either user-supplied or default local addresses + // If one satisfying our conditions is available, use it. + var userIPs []net.IP + if isNSIPv6 { + userIPs = r.userPreferredIPv6LocalAddrs + } else { + userIPs = r.userPreferredIPv4LocalAddrs + } + // Shuffle the slice in random order so that we don't always use the same local address + rand.Shuffle(len(userIPs), func(i, j int) { + userIPs[i], userIPs[j] = userIPs[j], userIPs[i] + }) + var localAddr *net.IP + for _, ip := range userIPs { + if isLoopback == ip.IsLoopback() { + localAddr = &ip + break + } + } + + if localAddr == nil { + // none of the user-supplied IPs match the conditions, we need to select one + if isLoopback && isNSIPv6 { + ip := net.ParseIP(DefaultLoopbackIPv6Addr) + localAddr = &ip + } else if isLoopback { + ip := net.ParseIP(DefaultLoopbackIPv4Addr) + localAddr = &ip + } else { + // non-loopback, attempt to reach the nameserver from the internet and get the local addr. used + conn, err := net.Dial("udp", nameServer.String()) + if err != nil { + return nil, fmt.Errorf("unable to find default IP address to open socket: %w", err) + } + localAddr = &conn.LocalAddr().(*net.UDPAddr).IP + // cleanup socket + if err = conn.Close(); err != nil { + log.Error("unable to close test connection to Google public DNS: ", err) + } + } + if localAddr != nil { + log.Infof("none of the user-supplied local addresses could connect to name server %s, using local address %s", nameServer.String(), localAddr.String()) + } + } + if localAddr == nil { + return nil, errors.New("unable to find local address for connection") + } connInfo := &ConnectionInfo{ - localAddr: localAddr[rand.Intn(len(localAddr))], + localAddr: *localAddr, } - if shouldRecycleSockets { + if r.shouldRecycleSockets { // create persistent connection conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: connInfo.localAddr}) if err != nil { @@ -408,25 +426,35 @@ func getConnectionInfo(localAddr []net.IP, transportMode transportMode, timeout connInfo.conn.Conn = conn } - usingUDP := transportMode == UDPOrTCP || transportMode == UDPOnly + usingUDP := r.transportMode == UDPOrTCP || r.transportMode == UDPOnly if usingUDP { connInfo.udpClient = new(dns.Client) - connInfo.udpClient.Timeout = timeout + connInfo.udpClient.Timeout = r.timeout connInfo.udpClient.Dialer = &net.Dialer{ - Timeout: timeout, + Timeout: r.timeout, LocalAddr: &net.UDPAddr{IP: connInfo.localAddr}, } } - usingTCP := transportMode == UDPOrTCP || transportMode == TCPOnly + usingTCP := r.transportMode == UDPOrTCP || r.transportMode == TCPOnly if usingTCP { connInfo.tcpClient = new(dns.Client) connInfo.tcpClient.Net = "tcp" - connInfo.tcpClient.Timeout = timeout + connInfo.tcpClient.Timeout = r.timeout connInfo.tcpClient.Dialer = &net.Dialer{ - Timeout: timeout, + Timeout: r.timeout, LocalAddr: &net.TCPAddr{IP: connInfo.localAddr}, } } + // save the connection info for future use + if isNSIPv6 && isLoopback { + r.connInfoIPv6Loopback = connInfo + } else if isNSIPv6 { + r.connInfoIPv6Internet = connInfo + } else if isLoopback { + r.connInfoIPv4Loopback = connInfo + } else { + r.connInfoIPv4Internet = connInfo + } return connInfo, nil } @@ -448,18 +476,7 @@ func (r *Resolver) ExternalLookup(q *Question, dstServer *NameServer) (*SingleQu } dstServer.PopulateDefaultPort() if isValid, reason := dstServer.IsValid(); !isValid { - return nil, nil, StatusIllegalInput, fmt.Errorf("could not parse name server (%s): %s", dstServer.String(), reason) - } - if util.IsIPv6(&dstServer.IP) && r.connInfoIPv6 == nil { - return nil, nil, StatusIllegalInput, fmt.Errorf("IPv6 external lookup requested for domain %s but no IPv6 local addresses provided to resolver", q.Name) - } else if dstServer.IP.To4() != nil && r.connInfoIPv4 == nil { - return nil, nil, StatusIllegalInput, fmt.Errorf("IPv4 external lookup requested for domain %s but no IPv4 local addresses provided to resolver", q.Name) - } - // check that local address and dstServer's don't have a loopback mismatch - if dstServer.IP.To4() != nil && r.connInfoIPv4.localAddr.IsLoopback() != dstServer.IP.IsLoopback() { - return nil, nil, StatusIllegalInput, errors.New("cannot mix loopback and non-loopback addresses") - } else if util.IsIPv6(&dstServer.IP) && r.connInfoIPv6.localAddr.IsLoopback() != dstServer.IP.IsLoopback() { - return nil, nil, StatusIllegalInput, errors.New("cannot mix loopback and non-loopback addresses") + return nil, nil, StatusIllegalInput, fmt.Errorf("destination server %s is invalid: %s", dstServer.String(), reason) } // dstServer has been validated and has a port, continue with lookup lookup, trace, status, err := r.lookupClient.DoSingleDstServerLookup(r, *q, dstServer, false) @@ -482,16 +499,26 @@ func (r *Resolver) IterativeLookup(q *Question) (*SingleQueryResult, Trace, Stat // Close cleans up any resources used by the resolver. This should be called when the resolver is no longer needed. // Lookup will panic if called after Close. func (r *Resolver) Close() { - if r.connInfoIPv4 != nil && r.connInfoIPv4.conn != nil { - if err := r.connInfoIPv4.conn.Close(); err != nil { + if r.connInfoIPv4Internet != nil && r.connInfoIPv4Internet.conn != nil { + if err := r.connInfoIPv4Internet.conn.Close(); err != nil { log.Errorf("error closing IPv4 connection: %v", err) } } - if r.connInfoIPv6 != nil && r.connInfoIPv6.conn != nil { - if err := r.connInfoIPv6.conn.Close(); err != nil { + if r.connInfoIPv6Internet != nil && r.connInfoIPv6Internet.conn != nil { + if err := r.connInfoIPv6Internet.conn.Close(); err != nil { log.Errorf("error closing IPv6 connection: %v", err) } } + if r.connInfoIPv4Loopback != nil && r.connInfoIPv4Loopback.conn != nil { + if err := r.connInfoIPv4Loopback.conn.Close(); err != nil { + log.Errorf("error closing IPv4 loopback connection: %v", err) + } + } + if r.connInfoIPv6Loopback != nil && r.connInfoIPv6Loopback.conn != nil { + if err := r.connInfoIPv6Loopback.conn.Close(); err != nil { + log.Errorf("error closing IPv6 loopback connection: %v", err) + } + } } func (r *Resolver) randomExternalNameServer() *NameServer { diff --git a/src/zdns/resolver_test.go b/src/zdns/resolver_test.go index fa215566..009d77e2 100644 --- a/src/zdns/resolver_test.go +++ b/src/zdns/resolver_test.go @@ -65,49 +65,4 @@ func TestResolverConfig_Validate(t *testing.T) { err := rc.Validate() require.NotNil(t, err) }) - t.Run("Missing local addr", func(t *testing.T) { - rc := &ResolverConfig{ - ExternalNameServersV4: []NameServer{{IP: net.ParseIP("127.0.0.53"), Port: 53}}, - RootNameServersV4: []NameServer{{IP: net.ParseIP("127.0.0.53"), Port: 53}}, - } - err := rc.Validate() - require.NotNil(t, err) - }) - - t.Run("Cannot mix loopback addresses in nameservers", func(t *testing.T) { - rc := &ResolverConfig{ - ExternalNameServersV4: []NameServer{{IP: net.ParseIP("127.0.0.53"), Port: 53}, {IP: net.ParseIP("1.1.1.1"), Port: 53}}, - RootNameServersV4: []NameServer{{IP: net.ParseIP("127.0.0.53"), Port: 53}}, - LocalAddrsV4: []net.IP{net.ParseIP("127.0.0.1")}, - } - err := rc.Validate() - require.NotNil(t, err) - }) - t.Run("Cannot mix loopback addresses among nameservers", func(t *testing.T) { - rc := &ResolverConfig{ - ExternalNameServersV4: []NameServer{{IP: net.ParseIP("1.1.1.1"), Port: 53}}, - RootNameServersV4: []NameServer{{IP: net.ParseIP("127.0.0.53"), Port: 53}}, - LocalAddrsV4: []net.IP{net.ParseIP("127.0.0.1")}, - } - err := rc.Validate() - require.NotNil(t, err) - }) - t.Run("Cannot reach loopback NSes from non-loopback local address", func(t *testing.T) { - rc := &ResolverConfig{ - ExternalNameServersV4: []NameServer{{IP: net.ParseIP("127.0.0.53"), Port: 53}}, - RootNameServersV4: []NameServer{{IP: net.ParseIP("127.0.0.53"), Port: 53}}, - LocalAddrsV4: []net.IP{net.ParseIP("192.168.1.2")}, - } - err := rc.Validate() - require.NotNil(t, err) - }) - t.Run("Cannot reach non-loopback NSes from loopback local address", func(t *testing.T) { - rc := &ResolverConfig{ - ExternalNameServersV4: []NameServer{{IP: net.ParseIP("1.1.1.1"), Port: 53}}, - RootNameServersV4: []NameServer{{IP: net.ParseIP("1.1.1.1"), Port: 53}}, - LocalAddrsV4: []net.IP{net.ParseIP("127.0.0.1")}, - } - err := rc.Validate() - require.NotNil(t, err) - }) } diff --git a/src/zdns/util.go b/src/zdns/util.go index d5571c5d..9b45c66f 100644 --- a/src/zdns/util.go +++ b/src/zdns/util.go @@ -184,3 +184,15 @@ func handleStatus(status Status, err error) (Status, error) { return s, nil } } + +// DeepCopyIPs creates a deep copy of a slice of net.IP +func DeepCopyIPs(ips []net.IP) []net.IP { + copied := make([]net.IP, len(ips)) + for i, ip := range ips { + if ip != nil { + // Deep copy the IP by copying the underlying byte slice + copied[i] = append(net.IP(nil), ip...) + } + } + return copied +} diff --git a/testing/integration_tests.py b/testing/integration_tests.py index f61da428..2ee47736 100755 --- a/testing/integration_tests.py +++ b/testing/integration_tests.py @@ -1325,5 +1325,33 @@ def test_a_lookup_domain_name_server_with_input(self): self.assertEqualAnswers(res, self.ROOT_A_ANSWERS, cmd, "A") + def test_a_lookup_IP_name_server_with_input(self): + c = "A" + name = "zdns-testing.com,1.1.1.1" + cmd, res = self.run_zdns(c, name) + self.assertSuccess(res, cmd, "A") + self.assertEqualAnswers(res, self.ROOT_A_ANSWERS, cmd, "A") + self.assertEqual(res["results"]["A"]["data"]["resolver"], "1.1.1.1:53") + + def test_a_lookup_IP_name_server_with_input_flag_mismatch(self): + c = "A --name-servers=1.1.1.1" + name = "zdns-testing.com,8.8.8.8" + cmd, res = self.run_zdns(c, name) + self.assertSuccess(res, cmd, "A") + self.assertEqualAnswers(res, self.ROOT_A_ANSWERS, cmd, "A") + self.assertEqual(res["results"]["A"]["data"]["resolver"], "8.8.8.8:53", "user-supplied name server with input " + "should take precedence") + + def test_a_lookup_IP_name_server_with_input_flag_loopback_mismatch(self): + c = "A --name-servers=127.0.0.1" + name = "zdns-testing.com,8.8.8.8" + cmd, res = self.run_zdns(c, name) + self.assertSuccess(res, cmd, "A") + self.assertEqualAnswers(res, self.ROOT_A_ANSWERS, cmd, "A") + self.assertEqual(res["results"]["A"]["data"]["resolver"], "8.8.8.8:53", "user-supplied name server with input " + "should take precedence") + + + if __name__ == "__main__": unittest.main() diff --git a/testing/ipv6_tests.py b/testing/ipv6_tests.py index 3205a7b4..48c05d48 100644 --- a/testing/ipv6_tests.py +++ b/testing/ipv6_tests.py @@ -21,37 +21,50 @@ def run_zdns(self, flags, name, executable=ZDNS_EXECUTABLE): o = subprocess.check_output(c, shell=True) return c, json.loads(o.rstrip()) - def assertSuccess(self, res, cmd): - self.assertEqual(res["status"], "NOERROR", cmd) + def assertSuccess(self, res, cmd, query_type): + self.assertEqual(res["results"][query_type]["status"], "NOERROR", cmd) - def assertServFail(self, res, cmd): - self.assertEqual(res["status"], "SERVFAIL", cmd) + def assertServFail(self, res, cmd, query_type): + self.assertEqual(res["results"][query_type]["status"], "SERVFAIL", cmd) - def assertEqualAnswers(self, res, correct, cmd, key="answer"): - self.assertIn("answers", res["data"]) - for answer in res["data"]["answers"]: + def assertEqualAnswers(self, res, correct, cmd, query_type, key="answer"): + self.assertIn("answers", res["results"][query_type]["data"]) + for answer in res["results"][query_type]["data"]["answers"]: del answer["ttl"] - a = sorted(res["data"]["answers"], key=lambda x: x[key]) + a = sorted(res["results"][query_type]["data"]["answers"], key=lambda x: x[key]) b = sorted(correct, key=lambda x: x[key]) helptext = "%s\nExpected:\n%s\n\nActual:\n%s" % (cmd, json.dumps(b, indent=4), json.dumps(a, indent=4)) + def _lowercase(obj): + """ Make dictionary lowercase """ + if isinstance(obj, dict): + for k, v in obj.items(): + if k == "name": + obj[k] = v.lower() + else: + _lowercase(v) + + _lowercase(a) + _lowercase(b) + self.assertEqual(a, b, helptext) + def test_a_ipv6(self): - c = "A --name-servers=[2001:4860:4860::8888]:53" + c = "A --name-servers='[2001:4860:4860::8888]:53'" name = "zdns-testing.com" cmd, res = self.run_zdns(c, name) - self.assertSuccess(res, cmd) - self.assertEqualAnswers(res, self.ROOT_A_ANSWERS, cmd) + self.assertSuccess(res, cmd, "A") + self.assertEqualAnswers(res, self.ROOT_A_ANSWERS, cmd, "A") def test_ipv6_unreachable(self): c = "A --iterative --6" name = "esrg.stanford.edu" cmd, res = self.run_zdns(c, name) # esrg.stanford.edu is hosted on NS's that do not have an IPv6 address. Therefore, the lookup won't get sufficient glue records to resolve the query. - self.assertEqual(res["status"], "NONEEDEDGLUE", cmd) + self.assertEqual(res["results"]["A"]["status"], "NONEEDEDGLUE", cmd) def test_ipv6_external_lookup_unreachable_nameserver(self): - c = "A --6=true --4=false --name-servers=1.1.1.1" + c = "A --6 --name-servers=1.1.1.1" name = "zdns-testing.com" try: cmd, res = self.run_zdns(c, name) @@ -60,7 +73,7 @@ def test_ipv6_external_lookup_unreachable_nameserver(self): self.fail("Should have thrown an exception, shouldn't be able to reach any IPv4 servers while in IPv6 mode") def test_ipv4_external_lookup_unreachable_nameserver(self): - c = "A --6=false --4=true --name-servers=2606:4700:4700::1111" + c = "A --4 --name-servers=2606:4700:4700::1111" name = "zdns-testing.com" try: cmd, res = self.run_zdns(c, name) @@ -68,35 +81,26 @@ def test_ipv4_external_lookup_unreachable_nameserver(self): return True self.fail("Should have thrown an exception, shouldn't be able to reach any IPv6 servers while in IPv4 mode") - def test_ipv6_external_lookup_loopback_nameserver(self): - c = "A --6=true --4=false --name-servers=[::1]:53" - name = "zdns-testing.com" - try: - cmd, res = self.run_zdns(c, name) - except Exception as e: - return True - self.fail("Should have thrown an exception, shouldn't be able to use a loopback address as a nameserver in IPv6") - def test_ipv6_happy_path_external(self): - c = "A --6=true" + c = "A --6" name = "zdns-testing.com" cmd, res = self.run_zdns(c, name) - self.assertSuccess(res, cmd) - self.assertEqualAnswers(res, self.ROOT_A_ANSWERS, cmd) + self.assertSuccess(res, cmd, "A") + self.assertEqualAnswers(res, self.ROOT_A_ANSWERS, cmd, "A") def test_ipv6_happy_path_iterative(self): - c = "A --6=true --iterative" + c = "A --6 --iterative" name = "zdns-testing.com" cmd, res = self.run_zdns(c, name) - self.assertSuccess(res, cmd) - self.assertEqualAnswers(res, self.ROOT_A_ANSWERS, cmd) + self.assertSuccess(res, cmd, "A") + self.assertEqualAnswers(res, self.ROOT_A_ANSWERS, cmd, "A") def test_ipv6_happy_path_no_ipv4_iterative(self): - c = "A --6=true --4=false --iterative" + c = "A --6 --iterative" name = "zdns-testing.com" cmd, res = self.run_zdns(c, name) - self.assertSuccess(res, cmd) - self.assertEqualAnswers(res, self.ROOT_A_ANSWERS, cmd) + self.assertSuccess(res, cmd, "A") + self.assertEqualAnswers(res, self.ROOT_A_ANSWERS, cmd, "A") if __name__ == "__main__":