diff --git a/makefile b/makefile index 52157729..ef12dc53 100644 --- a/makefile +++ b/makefile @@ -17,6 +17,11 @@ integration-tests: zdns python3 testing/integration_tests.py python3 testing/large_scan_integration/large_scan_integration_tests.py +# Not all hosts support this, so this will be a custom make target +ipv6-tests: zdns + pip3 install -r testing/requirements.txt + python3 testing/ipv6_tests.py + lint: goimports -w -local "github.com/zmap/zdns" ./ gofmt -s -w ./ diff --git a/src/cli/alookup.go b/src/cli/alookup.go index 755fcb1e..421d134e 100644 --- a/src/cli/alookup.go +++ b/src/cli/alookup.go @@ -11,6 +11,7 @@ * implied. See the License for the specific language governing * permissions and limitations under the License. */ + package cli import ( @@ -40,8 +41,8 @@ Specifically, alookup acts similar to nslookup and will follow CNAME records.`, func init() { rootCmd.AddCommand(alookupCmd) - alookupCmd.PersistentFlags().Bool("ipv4-lookup", false, "perform A lookups for each MX server") - alookupCmd.PersistentFlags().Bool("ipv6-lookup", false, "perform AAAA record lookups for each MX server") + alookupCmd.PersistentFlags().Bool("ipv4-lookup", false, "perform A lookups for each server") + alookupCmd.PersistentFlags().Bool("ipv6-lookup", false, "perform AAAA record lookups for each server") util.BindFlags(alookupCmd, viper.GetViper(), util.EnvPrefix) } diff --git a/src/cli/cli.go b/src/cli/cli.go index b8f96519..3cb9befa 100644 --- a/src/cli/cli.go +++ b/src/cli/cli.go @@ -69,6 +69,10 @@ type CLIConf struct { LookupAllNameServers bool TCPOnly bool UDPOnly bool + IPv4TransportOnly bool // IPv4 transport only, incompatible with IPv6 transport only + IPv6TransportOnly bool // IPv6 transport only, incompatible with IPv4 transport only + PreferIPv4Iteration bool // Prefer IPv4/A record lookups during iterative resolution, only used if both IPv4 and IPv6 transport are enabled + PreferIPv6Iteration bool // Prefer IPv6/AAAA record lookups during iterative resolution, only used if both IPv4 and IPv6 transport are enabled RecycleSockets bool LocalAddrSpecified bool LocalAddrs []net.IP @@ -165,6 +169,11 @@ func init() { rootCmd.PersistentFlags().StringVar(&GC.NameServersString, "name-servers", "", "List of DNS servers to use. Can be passed as comma-delimited string or via @/path/to/file. If no port is specified, defaults to 53.") rootCmd.PersistentFlags().StringVar(&GC.LocalAddrString, "local-addr", "", "comma-delimited list of local addresses to use, serve as the source IP for outbound queries") rootCmd.PersistentFlags().StringVar(&GC.LocalIfaceString, "local-interface", "", "local interface to use") + rootCmd.PersistentFlags().BoolVar(&GC.IPv4TransportOnly, "4", false, "utilize IPv4 query transport only, incompatible with --6") + rootCmd.PersistentFlags().BoolVar(&GC.IPv6TransportOnly, "6", false, "utilize IPv6 query transport only, incompatible with --4") + rootCmd.PersistentFlags().BoolVar(&GC.PreferIPv4Iteration, "prefer-ipv4-iteration", false, "Prefer IPv4/A record lookups during iterative resolution. Ignored unless used with both IPv4 and IPv6") + rootCmd.PersistentFlags().BoolVar(&GC.PreferIPv6Iteration, "prefer-ipv6-iteration", false, "Prefer IPv6/AAAA record lookups during iterative resolution. Ignored unless used with both IPv4 and IPv6") + rootCmd.PersistentFlags().StringVar(&GC.ConfigFilePath, "conf-file", zdns.DefaultNameServerConfigFile, "config file for DNS servers") rootCmd.PersistentFlags().IntVar(&GC.Timeout, "timeout", 15, "timeout for resolving a individual name, in seconds") rootCmd.PersistentFlags().IntVar(&GC.IterationTimeout, "iteration-timeout", 4, "timeout for a single iterative step in an iterative query, in seconds. Only applicable with --iterative") @@ -174,8 +183,8 @@ func init() { rootCmd.PersistentFlags().BoolVar(&GC.Dnssec, "dnssec", false, "Requests DNSSEC records by setting the DNSSEC OK (DO) bit") rootCmd.PersistentFlags().BoolVar(&GC.UseNSID, "nsid", false, "Request NSID.") - rootCmd.PersistentFlags().Bool("ipv4-lookup", false, "Perform an IPv4 Lookup in modules") - rootCmd.PersistentFlags().Bool("ipv6-lookup", false, "Perform an IPv6 Lookup in modules") + rootCmd.PersistentFlags().Bool("ipv4-lookup", false, "Perform an IPv4 Lookup (requests A records) in modules") + rootCmd.PersistentFlags().Bool("ipv6-lookup", false, "Perform an IPv6 Lookup (requests AAAA recoreds) in modules") rootCmd.PersistentFlags().StringVar(&GC.BlacklistFilePath, "blacklist-file", "", "blacklist file for servers to exclude from lookups") } diff --git a/src/cli/config_validation.go b/src/cli/config_validation.go index 1fc1652c..e1964eca 100644 --- a/src/cli/config_validation.go +++ b/src/cli/config_validation.go @@ -39,6 +39,7 @@ func populateNetworkingConfig(gc *CLIConf) error { return errors.Wrap(err, "client subnet did not pass validation") } + // local address - the user can enter both IPv4 and IPv6 addresses. We'll differentiate them later if GC.LocalAddrString != "" { for _, la := range strings.Split(GC.LocalAddrString, ",") { ip := net.ParseIP(la) @@ -51,6 +52,7 @@ func populateNetworkingConfig(gc *CLIConf) error { gc.LocalAddrSpecified = true } + // local interface - same as local addresses, an interface could have both IPv4 and IPv6 addresses, we'll differentiate them later if gc.LocalIfaceString != "" { li, err := net.InterfaceByName(gc.LocalIfaceString) if err != nil { @@ -114,7 +116,7 @@ func parseNameServers(gc *CLIConf) error { if gc.NameServerMode { log.Fatal("name servers cannot be specified on command line in --name-server-mode") } - var ns []string + var nses []string if (gc.NameServersString)[0] == '@' { filepath := (gc.NameServersString)[1:] f, err := os.ReadFile(filepath) @@ -124,11 +126,16 @@ func parseNameServers(gc *CLIConf) error { if len(f) == 0 { log.Fatalf("Empty file (%s)", filepath) } - ns = strings.Split(strings.Trim(string(f), "\n"), "\n") + nses = strings.Split(strings.Trim(string(f), "\n"), "\n") } else { - ns = strings.Split(gc.NameServersString, ",") + nses = strings.Split(gc.NameServersString, ",") + trimmedNSes := make([]string, 0, len(nses)) + for _, ns := range nses { + trimmedNSes = append(trimmedNSes, strings.TrimSpace(ns)) + } + nses = trimmedNSes } - gc.NameServers = ns + gc.NameServers = nses } return nil } diff --git a/src/cli/config_validation_test.go b/src/cli/config_validation_test.go index efda52ca..334a1593 100644 --- a/src/cli/config_validation_test.go +++ b/src/cli/config_validation_test.go @@ -22,15 +22,17 @@ import ( func TestValidateNetworkingConfig(t *testing.T) { t.Run("LocalAddr and LocalInterface both specified", func(t *testing.T) { gc := &CLIConf{ - LocalAddrString: "1.1.1.1", - LocalIfaceString: "eth0", + LocalAddrString: "1.1.1.1", + LocalIfaceString: "eth0", + IPv4TransportOnly: true, } err := populateNetworkingConfig(gc) require.NotNil(t, err, "Expected an error but got nil") }) t.Run("Using invalid interface", func(t *testing.T) { gc := &CLIConf{ - LocalIfaceString: "invalid_interface", + LocalIfaceString: "invalid_interface", + IPv4TransportOnly: true, } err := populateNetworkingConfig(gc) require.NotNil(t, err, "Expected an error but got nil") @@ -38,6 +40,7 @@ func TestValidateNetworkingConfig(t *testing.T) { t.Run("Using nameserver with port", func(t *testing.T) { gc := &CLIConf{ NameServersString: "127.0.0.1:53", + IPv4TransportOnly: true, } err := populateNetworkingConfig(gc) require.Nil(t, err, "Expected no error but got %v", err) diff --git a/src/cli/worker_manager.go b/src/cli/worker_manager.go index f0a4c415..51b1bd7f 100644 --- a/src/cli/worker_manager.go +++ b/src/cli/worker_manager.go @@ -39,7 +39,9 @@ import ( ) const ( - googleDNSResolverAddr = "8.8.8.8:53" + googleDNSResolverAddr = "8.8.8.8:53" + googleDNSResolverAddrV6 = "[2001:4860:4860::8888]:53" + loopbackIPv4Addr = "127.0.0.1" ) type routineMetadata struct { @@ -160,17 +162,9 @@ func populateCLIConfig(gc *CLIConf, flags *pflag.FlagSet) *CLIConf { return gc } -func populateResolverConfig(gc *CLIConf, flags *pflag.FlagSet) *zdns.ResolverConfig { +func populateResolverConfig(gc *CLIConf) *zdns.ResolverConfig { config := zdns.NewResolverConfig() - useIPv4, err := flags.GetBool("ipv4-lookup") - if err != nil { - log.Fatal("Unable to parse ipv4 flag: ", err) - } - useIPv6, err := flags.GetBool("ipv6-lookup") - if err != nil { - log.Fatal("Unable to parse ipv6 flag: ", err) - } - config.IPVersionMode = zdns.GetIPVersionMode(useIPv4, useIPv6) + config.TransportMode = zdns.GetTransportMode(gc.UDPOnly, gc.TCPOnly) config.Timeout = time.Second * time.Duration(gc.Timeout) @@ -197,18 +191,52 @@ func populateResolverConfig(gc *CLIConf, flags *pflag.FlagSet) *zdns.ResolverCon if gc.BlacklistFilePath != "" { config.Blacklist = blacklist.New() - if err = config.Blacklist.ParseFromFile(gc.BlacklistFilePath); err != nil { + if err := config.Blacklist.ParseFromFile(gc.BlacklistFilePath); err != nil { log.Fatal("unable to parse blacklist file: ", err) } } // This must occur after setting the DNSConfigFilePath above, so that ZDNS knows where to fetch the DNS Config + config, err := populateIPTransportMode(gc, config) + if err != nil { + log.Fatal("could not populate IP transport mode: ", err) + } + // This is used in extractAuthorities where we need to know whether to request A or AAAA records to continue iteration + // Must be set after populating IPTransportMode + if config.IPVersionMode == zdns.IPv4Only { + config.IterationIPPreference = zdns.PreferIPv4 + } else if config.IPVersionMode == zdns.IPv6Only { + config.IterationIPPreference = zdns.PreferIPv6 + } else if config.IPVersionMode == zdns.IPv4OrIPv6 && !gc.PreferIPv4Iteration && !gc.PreferIPv6Iteration { + // need to specify some type of preference, we'll default to IPv4 and inform the user + log.Info("No iteration IP preference specified, defaulting to IPv4 preferred. See --prefer-ipv4-iteration and --prefer-ipv6-iteration for more info") + config.IterationIPPreference = zdns.PreferIPv4 + } else if config.IPVersionMode == zdns.IPv4OrIPv6 && gc.PreferIPv4Iteration && gc.PreferIPv6Iteration { + log.Fatal("Cannot specify both --prefer-ipv4-iteration and --prefer-ipv6-iteration") + } else { + config.IterationIPPreference = zdns.GetIterationIPPreference(gc.PreferIPv4Iteration, gc.PreferIPv6Iteration) + } + // This must occur after setting the DNSConfigFilePath above, so that ZDNS knows where to fetch the DNS Config config, err = populateNameServers(gc, config) if err != nil { log.Fatal("could not populate name servers: ", err) } // User/OS defaults could contain duplicates, remove - config.ExternalNameServers = util.RemoveDuplicates(config.ExternalNameServers) - config.RootNameServers = util.RemoveDuplicates(config.RootNameServers) + config.ExternalNameServersV4 = util.RemoveDuplicates(config.ExternalNameServersV4) + config.RootNameServersV4 = util.RemoveDuplicates(config.RootNameServersV4) + config.ExternalNameServersV6 = util.RemoveDuplicates(config.ExternalNameServersV6) + config.RootNameServersV6 = util.RemoveDuplicates(config.RootNameServersV6) + + if config.IPVersionMode == zdns.IPv4Only { + // Drop any IPv6 nameservers + config.ExternalNameServersV6 = []string{} + config.RootNameServersV6 = []string{} + } + if config.IPVersionMode == zdns.IPv6Only { + // Drop any IPv4 nameservers + config.ExternalNameServersV4 = []string{} + config.RootNameServersV4 = []string{} + } + config, err = populateLocalAddresses(gc, config) if err != nil { log.Fatal("could not populate local addresses: ", err) @@ -216,6 +244,75 @@ func populateResolverConfig(gc *CLIConf, flags *pflag.FlagSet) *zdns.ResolverCon return config } +// populateIPTransportMode populates the IPTransportMode field of the ResolverConfig +// If user sets --4 (IPv4 Only) or --6 (IPv6 Only), we'll set the IPVersionMode to IPv4Only or IPv6Only, respectively. +// Otherwise, we need to determine the IPVersionMode based on either the OS' default resolver(s) or the user's provided name servers. +// Note: populateNameServers must be called before this function to ensure the nameservers are populated. +func populateIPTransportMode(gc *CLIConf, config *zdns.ResolverConfig) (*zdns.ResolverConfig, error) { + if gc.IPv4TransportOnly && gc.IPv6TransportOnly { + return nil, errors.New("only one of --4 and --6 allowed") + } + if gc.IPv4TransportOnly { + config.IPVersionMode = zdns.IPv4Only + return config, nil + } + if gc.IPv6TransportOnly { + config.IPVersionMode = zdns.IPv6Only + return config, nil + } + nameServersSupportIPv4 := false + nameServersSupportIPv6 := false + // Check if user provided nameservers + if len(gc.NameServers) != 0 { + // Check that the nameservers have a port and append one if necessary + portValidatedNSs := make([]string, 0, len(gc.NameServers)) + // check that the nameservers have a port and append one if necessary + for _, ns := range gc.NameServers { + portNS, err := util.AddDefaultPortToDNSServerName(ns) + if err != nil { + return nil, fmt.Errorf("could not parse name server: %s. Correct IPv4 format: 1.1.1.1:53 or IPv6 format: [::1]:53", ns) + } + portValidatedNSs = append(portValidatedNSs, portNS) + } + v4NameServers, v6NameServers, err := util.SplitIPv4AndIPv6Addrs(portValidatedNSs) + if err != nil { + return nil, errors.Wrap(err, "could not split IPv4 and IPv6 addresses for nameservers") + } + if len(v4NameServers) != 0 { + nameServersSupportIPv4 = true + } + if len(v6NameServers) != 0 { + nameServersSupportIPv6 = true + } + } else { + // User did not provide nameservers, check the OS' default resolver(s) + v4NameServers, v6NameServers, err := zdns.GetDNSServers(config.DNSConfigFilePath) + if err != nil { + log.Warn("Unable to parse resolvers file to determine if IPv4 or IPv6 is supported. Defaulting to IPv4") + config.IPVersionMode = zdns.IPv4Only + return config, nil + } + if len(v4NameServers) != 0 { + nameServersSupportIPv4 = true + } + if len(v6NameServers) != 0 { + nameServersSupportIPv6 = true + } + } + if nameServersSupportIPv4 && nameServersSupportIPv6 { + config.IPVersionMode = zdns.IPv4OrIPv6 + return config, nil + } else if nameServersSupportIPv4 { + config.IPVersionMode = zdns.IPv4Only + return config, nil + } else if nameServersSupportIPv6 { + config.IPVersionMode = zdns.IPv6Only + return config, nil + } else { + return nil, errors.New("no nameservers found with OS defaults. Please specify desired nameservers with --name-servers") + } +} + func populateNameServers(gc *CLIConf, config *zdns.ResolverConfig) (*zdns.ResolverConfig, error) { // Nameservers are populated in this order: // 1. If user provided nameservers, use those @@ -225,6 +322,7 @@ func populateNameServers(gc *CLIConf, config *zdns.ResolverConfig) (*zdns.Resolv // Additionally, both Root and External nameservers must be populated, since the Resolver doesn't know we'll only // be performing either iterative or recursive lookups, not both. + // IPv4 Name Servers/Local Address only needs to be populated if we're doing IPv4 lookups, same for IPv6 if len(gc.NameServers) != 0 { // User provided name servers, use them. // Check that the nameservers have a port and append one if necessary @@ -233,44 +331,50 @@ func populateNameServers(gc *CLIConf, config *zdns.ResolverConfig) (*zdns.Resolv for _, ns := range gc.NameServers { portNS, err := util.AddDefaultPortToDNSServerName(ns) if err != nil { - // TODO Update error msg when we add IPv6 - return nil, fmt.Errorf("could not parse name server: %s. Correct IPv4 format: 1.1.1.1:53", ns) + return nil, fmt.Errorf("could not parse name server: %s. Correct IPv4 format: 1.1.1.1:53 or IPv6 format: [::1]:53", ns) } portValidatedNSs = append(portValidatedNSs, portNS) } - config.ExternalNameServers = portValidatedNSs - config.RootNameServers = portValidatedNSs + v4NameServers, v6NameServers, err := util.SplitIPv4AndIPv6Addrs(portValidatedNSs) + if err != nil { + return nil, errors.Wrap(err, "could not split IPv4 and IPv6 addresses for nameservers") + } + // The resolver will ignore IPv6 nameservers if we're doing IPv4 only lookups, and vice versa so this is fine + config.ExternalNameServersV4 = v4NameServers + config.RootNameServersV4 = v4NameServers + config.ExternalNameServersV6 = v6NameServers + config.RootNameServersV6 = v6NameServers return config, nil } // User did not provide nameservers if !gc.IterativeResolution { // Try to get the OS' default recursive resolver nameservers - ns, err := zdns.GetDNSServers(config.DNSConfigFilePath) + v4NameServers, v6NameServers, err := zdns.GetDNSServers(config.DNSConfigFilePath) if err != nil { - ns = util.GetDefaultResolvers() - log.Warn("Unable to parse resolvers file. Using ZDNS defaults: ", strings.Join(ns, ", ")) + v4NameServers, v6NameServers = zdns.DefaultExternalResolversV4, zdns.DefaultExternalResolversV6 + log.Warn("Unable to parse resolvers file. Using ZDNS defaults: ", strings.Join(util.Concat(v4NameServers, v6NameServers), ", ")) } - // TODO remove when we add IPv6 support, without this if the user's OS defaults contain both IPv4/6 - // It'll fail validation since we can't currently handle those - ipv4NameServers := make([]string, 0, len(ns)) - for _, addr := range ns { - ip, _, err := util.SplitHostPort(addr) - if err != nil { - return nil, errors.Wrapf(err, "could not split host and port for nameserver: %s", addr) + if config.IPVersionMode != zdns.IPv6Only { + if len(v4NameServers) == 0 { + return nil, errors.New("no IPv4 nameservers found. Please specify desired nameservers with --name-servers") } - if ip.To4() == nil { - log.Infof("Ignoring non-IPv4 nameserver: %s", ip.String()) - continue + config.ExternalNameServersV4 = v4NameServers + config.RootNameServersV4 = v4NameServers + } + if config.IPVersionMode != zdns.IPv4Only { + if len(v6NameServers) == 0 { + return nil, errors.New("no IPv6 nameservers found. Please specify desired nameservers with --name-servers") } - ipv4NameServers = append(ipv4NameServers, addr) + config.ExternalNameServersV6 = v6NameServers + config.RootNameServersV6 = v6NameServers } - config.ExternalNameServers = ipv4NameServers - config.RootNameServers = ipv4NameServers return config, nil } // User did not provide nameservers and we're doing iterative resolution, use ZDNS defaults - config.ExternalNameServers = zdns.RootServersV4[:] - config.RootNameServers = zdns.RootServersV4[:] + config.ExternalNameServersV4 = zdns.RootServersV4[:] + config.RootNameServersV4 = zdns.RootServersV4[:] + config.ExternalNameServersV6 = zdns.RootServersV6[:] + config.RootNameServersV6 = zdns.RootServersV6[:] return config, nil } @@ -279,28 +383,33 @@ func populateLocalAddresses(gc *CLIConf, config *zdns.ResolverConfig) (*zdns.Res // 1. If user provided local addresses, use those // 2. If the config's nameservers are loopback, use the local loopback address // 3. Otherwise, try to connect to Google's recursive resolver and take the IP address we use for the connection + + // IPv4 local addresses are required for IPv4 lookups, same for IPv6 if len(gc.LocalAddrs) != 0 { // if user provided a local address(es), that takes precedent - // TODO remove when we add IPv6 support, without this if the user provides --local-interface with both IPv4/6 - // It'll fail validation since we can't currently handle those - ipv4LocalAddrs := make([]net.IP, 0, len(gc.LocalAddrs)) + config.LocalAddrsV4, config.LocalAddrsV6 = []net.IP{}, []net.IP{} for _, addr := range gc.LocalAddrs { - if addr.To4() == nil { - log.Infof("Ignoring non-IPv4 local address: %s", addr.String()) - continue + if addr == nil { + return nil, errors.New("invalid nil local address") + } + if addr.To4() != nil { + config.LocalAddrsV4 = append(config.LocalAddrsV4, addr) + } else if util.IsIPv6(&addr) { + config.LocalAddrsV6 = append(config.LocalAddrsV6, addr) + } else { + return nil, fmt.Errorf("invalid local address: %s", addr.String()) } - ipv4LocalAddrs = append(ipv4LocalAddrs, addr) } - config.LocalAddrs = ipv4LocalAddrs return config, nil } // if the nameservers are loopback, use the loopback address - if len(config.ExternalNameServers) == 0 { + allNameServers := util.Concat(config.ExternalNameServersV4, config.ExternalNameServersV6, config.RootNameServersV4, config.RootNameServersV6) + if len(allNameServers) == 0 { // this shouldn't happen return nil, errors.New("name servers must be set before populating local addresses") } anyNameServersLoopack := false - for _, ns := range util.Concat(config.ExternalNameServers, config.RootNameServers) { + for _, ns := range allNameServers { ip, _, err := util.SplitHostPort(ns) if err != nil { return nil, errors.Wrapf(err, "could not split host and port for nameserver: %s", ns) @@ -310,19 +419,34 @@ func populateLocalAddresses(gc *CLIConf, config *zdns.ResolverConfig) (*zdns.Res break } } + if anyNameServersLoopack { // set local address so name servers are reachable - config.LocalAddrs = []net.IP{net.ParseIP(zdns.LoopbackAddrString)} + config.LocalAddrsV4 = []net.IP{net.ParseIP(loopbackIPv4Addr)} + // loopback nameservers not supported for IPv6, we'll let Resolver validation take care of this } else { // localAddr not set, so we need to find the default IP address - conn, err := net.Dial("udp", googleDNSResolverAddr) - if err != nil { - return nil, fmt.Errorf("unable to find default IP address to open socket: %w", err) + if config.IPVersionMode != zdns.IPv6Only { + conn, err := net.Dial("udp", googleDNSResolverAddr) + if err != nil { + return nil, fmt.Errorf("unable to find default IP address to open socket: %w", err) + } + config.LocalAddrsV4 = []net.IP{conn.LocalAddr().(*net.UDPAddr).IP} + // cleanup socket + if err = conn.Close(); err != nil { + log.Error("unable to close test connection to Google public DNS: ", err) + } } - config.LocalAddrs = []net.IP{conn.LocalAddr().(*net.UDPAddr).IP} - // cleanup socket - if err = conn.Close(); err != nil { - log.Error("unable to close test connection to Google public DNS: ", err) + if config.IPVersionMode != zdns.IPv4Only { + conn, err := net.Dial("udp", googleDNSResolverAddrV6) + if err != nil { + return nil, fmt.Errorf("unable to find default IP address to open socket: %w", err) + } + config.LocalAddrsV6 = []net.IP{conn.LocalAddr().(*net.UDPAddr).IP} + // cleanup socket + if err = conn.Close(); err != nil { + log.Error("unable to close test connection to Google public IPv6 DNS: ", err) + } } } return config, nil @@ -330,7 +454,7 @@ func populateLocalAddresses(gc *CLIConf, config *zdns.ResolverConfig) (*zdns.Res func Run(gc CLIConf, flags *pflag.FlagSet) { gc = *populateCLIConfig(&gc, flags) - resolverConfig := populateResolverConfig(&gc, flags) + resolverConfig := populateResolverConfig(&gc) // Log any information about the resolver configuration, according to log level resolverConfig.PrintInfo() err := resolverConfig.Validate() diff --git a/src/internal/util/util.go b/src/internal/util/util.go index 616e2248..cbafcada 100644 --- a/src/internal/util/util.go +++ b/src/internal/util/util.go @@ -74,7 +74,30 @@ func SplitHostPort(inaddr string) (net.IP, int, error) { } return ip, portInt, nil +} +// SplitIPv4AndIPv6Addrs splits a list of IP addresses (either with port attached or not) into IPv4 and IPv6 addresses. +// Returns a slice of IPv4/IPv6 addresses that are guaranteed to be valid. If the port was attached, it'll be included. +func SplitIPv4AndIPv6Addrs(addrs []string) (ipv4 []string, ipv6 []string, err error) { + for _, addr := range addrs { + ip, _, err := SplitHostPort(addr) + if err != nil { + // addr may be an IP without a port + ip = net.ParseIP(addr) + } + if ip == nil { + return nil, nil, fmt.Errorf("invalid IP address: %s", addr) + } + // ip is valid, check if it's IPv4 or IPv6 + if ip.To4() != nil { + ipv4 = append(ipv4, addr) + } else if ip.To16() != nil { + ipv6 = append(ipv6, addr) + } else { + return nil, nil, fmt.Errorf("invalid IP address: %s", addr) + } + } + return ipv4, ipv6, nil } // Reference: https://github.com/carolynvs/stingoftheviper/blob/main/main.go @@ -103,11 +126,6 @@ func BindFlags(cmd *cobra.Command, v *viper.Viper, envPrefix string) { }) } -// getDefaultResolvers returns a slice of default DNS resolvers to be used when no system resolvers could be discovered. -func GetDefaultResolvers() []string { - return []string{"8.8.8.8:53", "8.8.4.4:53", "1.1.1.1:53", "1.0.0.1:53"} -} - // IsStringValidDomainName checks if the given string is a valid domain name using regex func IsStringValidDomainName(domain string) bool { var domainRegex = regexp.MustCompile(`^(?i)[a-z0-9]([a-z0-9-]{0,61}[a-z0-9])?(\.[a-z0-9]([a-z0-9-]{0,61}[a-z0-9])?)*\.[a-z]{2,}$`) diff --git a/src/modules/alookup/a_lookup.go b/src/modules/alookup/a_lookup.go index 44d557f8..41dbc77d 100644 --- a/src/modules/alookup/a_lookup.go +++ b/src/modules/alookup/a_lookup.go @@ -58,7 +58,7 @@ func (aMod *ALookupModule) Init(ipv4Lookup bool, ipv6Lookup bool) { } func (aMod *ALookupModule) Lookup(r *zdns.Resolver, lookupName, nameServer string) (interface{}, zdns.Trace, zdns.Status, error) { - ipResult, trace, status, err := r.DoTargetedLookup(lookupName, nameServer, zdns.GetIPVersionMode(aMod.IPv4Lookup, aMod.IPv6Lookup), aMod.baseModule.IsIterative) + ipResult, trace, status, err := r.DoTargetedLookup(lookupName, nameServer, aMod.baseModule.IsIterative, aMod.IPv4Lookup, aMod.IPv6Lookup) return ipResult, trace, status, err } diff --git a/src/modules/axfr/axfr_test.go b/src/modules/axfr/axfr_test.go index 394f8c96..59263b27 100644 --- a/src/modules/axfr/axfr_test.go +++ b/src/modules/axfr/axfr_test.go @@ -93,9 +93,10 @@ func InitTest() (*AxfrLookupModule, *zdns.Resolver) { cc := new(cli.CLIConf) rc := new(zdns.ResolverConfig) - rc.RootNameServers = []string{"127.0.0.53:53"} - rc.ExternalNameServers = []string{"127.0.0.53:53"} - rc.LocalAddrs = []net.IP{net.ParseIP("127.0.0.1")} + rc.RootNameServersV4 = []string{"127.0.0.53:53"} + rc.ExternalNameServersV4 = []string{"127.0.0.53:53"} + rc.LocalAddrsV4 = []net.IP{net.ParseIP("127.0.0.1")} + rc.IPVersionMode = zdns.IPv4Only flagSet := new(pflag.FlagSet) flagSet.Bool("ipv4-lookup", false, "Use IPv4") diff --git a/src/modules/bindversion/bindversion_test.go b/src/modules/bindversion/bindversion_test.go index c7508a12..988efd5f 100644 --- a/src/modules/bindversion/bindversion_test.go +++ b/src/modules/bindversion/bindversion_test.go @@ -47,10 +47,11 @@ func (ml MockLookup) DoSingleDstServerLookup(r *zdns.Resolver, question zdns.Que func InitTest(t *testing.T) *zdns.Resolver { mockResults = make(map[string]*zdns.SingleQueryResult) rc := zdns.ResolverConfig{ - ExternalNameServers: []string{"1.1.1.1:53"}, - RootNameServers: []string{"1.1.1.1:53"}, - LocalAddrs: []net.IP{net.ParseIP("192.168.1.1")}, - LookupClient: MockLookup{}} + ExternalNameServersV4: []string{"1.1.1.1:53"}, + RootNameServersV4: []string{"1.1.1.1:53"}, + LocalAddrsV4: []net.IP{net.ParseIP("192.168.1.1")}, + IPVersionMode: zdns.IPv4Only, + LookupClient: MockLookup{}} r, err := zdns.InitResolver(&rc) assert.NilError(t, err) diff --git a/src/modules/dmarc/dmarc_test.go b/src/modules/dmarc/dmarc_test.go index 9630b612..b217345d 100644 --- a/src/modules/dmarc/dmarc_test.go +++ b/src/modules/dmarc/dmarc_test.go @@ -47,10 +47,11 @@ func (ml MockLookup) DoSingleDstServerLookup(r *zdns.Resolver, question zdns.Que func InitTest(t *testing.T) *zdns.Resolver { mockResults = make(map[string]*zdns.SingleQueryResult) rc := zdns.ResolverConfig{ - ExternalNameServers: []string{"127.0.0.1:53"}, - RootNameServers: []string{"127.0.0.53:53"}, - LocalAddrs: []net.IP{net.ParseIP("127.0.0.1")}, - LookupClient: MockLookup{}} + ExternalNameServersV4: []string{"127.0.0.1:53"}, + RootNameServersV4: []string{"127.0.0.53:53"}, + LocalAddrsV4: []net.IP{net.ParseIP("127.0.0.1")}, + IPVersionMode: zdns.IPv4Only, + LookupClient: MockLookup{}} r, err := zdns.InitResolver(&rc) assert.NilError(t, err) diff --git a/src/modules/mxlookup/mx_lookup.go b/src/modules/mxlookup/mx_lookup.go index 75a255b6..7a251df3 100644 --- a/src/modules/mxlookup/mx_lookup.go +++ b/src/modules/mxlookup/mx_lookup.go @@ -110,7 +110,7 @@ func (mxMod *MXLookupModule) lookupIPs(r *zdns.Resolver, name, nameServer string return res.(CachedAddresses), zdns.Trace{} } retv := CachedAddresses{} - result, trace, status, _ := r.DoTargetedLookup(name, nameServer, ipMode, mxMod.IsIterative) + result, trace, status, _ := r.DoTargetedLookup(name, nameServer, mxMod.IsIterative, mxMod.IPv4Lookup, mxMod.IPv6Lookup) if status == zdns.StatusNoError && result != nil { retv.IPv4Addresses = result.IPv4Addresses retv.IPv6Addresses = result.IPv6Addresses diff --git a/src/modules/nslookup/ns_lookup.go b/src/modules/nslookup/ns_lookup.go index b79c2d93..71c2b01d 100644 --- a/src/modules/nslookup/ns_lookup.go +++ b/src/modules/nslookup/ns_lookup.go @@ -73,7 +73,7 @@ func (nsMod *NSLookupModule) Lookup(r *zdns.Resolver, lookupName string, nameSer log.Warn("iterative lookup requested with lookupName server, ignoring lookupName server") } - res, trace, status, err := r.DoNSLookup(lookupName, nameServer, nsMod.IsIterative) + res, trace, status, err := r.DoNSLookup(lookupName, nameServer, nsMod.IsIterative, nsMod.IPv4Lookup, nsMod.IPv6Lookup) if trace == nil { trace = zdns.Trace{} } diff --git a/src/modules/spf/spf_test.go b/src/modules/spf/spf_test.go index bc215c6e..55c0ed79 100644 --- a/src/modules/spf/spf_test.go +++ b/src/modules/spf/spf_test.go @@ -48,10 +48,11 @@ func InitTest(t *testing.T) *zdns.Resolver { mockResults = make(map[string]*zdns.SingleQueryResult) queries = make([]QueryRecord, 0) rc := zdns.ResolverConfig{ - ExternalNameServers: []string{"127.0.0.1:53"}, - RootNameServers: []string{"127.0.0.53:53"}, - LocalAddrs: []net.IP{net.ParseIP("127.0.0.1")}, - LookupClient: MockLookup{}} + ExternalNameServersV4: []string{"127.0.0.1:53"}, + RootNameServersV4: []string{"127.0.0.53:53"}, + LocalAddrsV4: []net.IP{net.ParseIP("127.0.0.1")}, + IPVersionMode: zdns.IPv4Only, + LookupClient: MockLookup{}} r, err := zdns.InitResolver(&rc) assert.NilError(t, err) diff --git a/src/zdns/alookup.go b/src/zdns/alookup.go index 68900cf9..ddadc41c 100644 --- a/src/zdns/alookup.go +++ b/src/zdns/alookup.go @@ -24,9 +24,7 @@ import ( // DoTargetedLookup performs a lookup of the given domain name against the given nameserver, looking up both IPv4 and IPv6 addresses // Will follow CNAME records as well as A/AAAA records to get IP addresses -func (r *Resolver) DoTargetedLookup(name, nameServer string, ipMode IPVersionMode, isIterative bool) (*IPResult, Trace, Status, error) { - lookupIPv4 := ipMode == IPv4Only || ipMode == IPv4OrIPv6 - lookupIPv6 := ipMode == IPv6Only || ipMode == IPv4OrIPv6 +func (r *Resolver) DoTargetedLookup(name, nameServer string, isIterative, lookupA, lookupAAAA bool) (*IPResult, Trace, Status, error) { name = strings.ToLower(name) res := IPResult{} candidateSet := map[string][]Answer{} @@ -39,7 +37,7 @@ func (r *Resolver) DoTargetedLookup(name, nameServer string, ipMode IPVersionMod var ipv4status Status var ipv6status Status - if lookupIPv4 { + if lookupA { ipv4, ipv4Trace, ipv4status, _ = recursiveIPLookup(r, name, nameServer, dns.TypeA, candidateSet, cnameSet, dnameSet, name, 0, isIterative) if len(ipv4) > 0 { ipv4 = Unique(ipv4) @@ -49,8 +47,7 @@ func (r *Resolver) DoTargetedLookup(name, nameServer string, ipMode IPVersionMod } candidateSet = map[string][]Answer{} cnameSet = map[string][]Answer{} - dnameSet = map[string][]Answer{} - if lookupIPv6 { + if lookupAAAA { ipv6, ipv6Trace, ipv6status, _ = recursiveIPLookup(r, name, nameServer, dns.TypeAAAA, candidateSet, cnameSet, dnameSet, name, 0, isIterative) if len(ipv6) > 0 { ipv6 = Unique(ipv6) @@ -64,9 +61,9 @@ func (r *Resolver) DoTargetedLookup(name, nameServer string, ipMode IPVersionMod // In case we get no IPs and a non-NOERROR status from either // IPv4 or IPv6 lookup, we return that status. if len(res.IPv4Addresses) == 0 && len(res.IPv6Addresses) == 0 { - if lookupIPv4 && !SafeStatus(ipv4status) { + if lookupA && !SafeStatus(ipv4status) { return nil, combinedTrace, ipv4status, nil - } else if lookupIPv6 && !SafeStatus(ipv6status) { + } else if lookupAAAA && !SafeStatus(ipv6status) { return nil, combinedTrace, ipv6status, nil } else { return &res, combinedTrace, StatusNoError, nil diff --git a/src/zdns/conf.go b/src/zdns/conf.go index 3264fb3b..b9fbea1e 100644 --- a/src/zdns/conf.go +++ b/src/zdns/conf.go @@ -51,9 +51,10 @@ const ( StatusTimeout Status = "TIMEOUT" StatusIterTimeout Status = "ITERATIVE_TIMEOUT" StatusNoAuth Status = "NOAUTH" + StatusNoNeededGlue Status = "NONEEDEDGLUE" // When a nameserver is authoritative for itself and the parent nameserver doesn't provide the glue to look it up ) -var RootServersV4 = [...]string{ +var RootServersV4 = []string{ "198.41.0.4:53", // A "170.247.170.2:53", // B - Changed several times, this is current as of July '24 "192.33.4.12:53", // C @@ -68,7 +69,7 @@ var RootServersV4 = [...]string{ "199.7.83.42:53", // L "202.12.27.33:53"} // M -var RootServersV6 = [...]string{ +var RootServersV6 = []string{ "[2001:503:ba3e::2:30]:53", // A "[2801:1b8:10::b]:53", // B "[2001:500:2::c]:53", // C @@ -83,3 +84,17 @@ var RootServersV6 = [...]string{ "[2001:500:9f::42]:53", // L "[2001:dc3::35]:53", // M } + +var DefaultExternalResolversV4 = []string{ + "8.8.8.8:53", + "8.8.4.4:53", + "1.1.1.1:53", + "1.0.0.1:53", +} + +var DefaultExternalResolversV6 = []string{ + "[2001:4860:4860::8888]:53", + "[2001:4860:4860::8844]:53", + "[2606:4700:4700::1111]:53", + "[2606:4700:4700::1001]:53", +} diff --git a/src/zdns/lookup.go b/src/zdns/lookup.go index c503d67c..75e6366c 100644 --- a/src/zdns/lookup.go +++ b/src/zdns/lookup.go @@ -27,11 +27,11 @@ import ( "github.com/zmap/zdns/src/internal/util" ) -// GetDNSServers returns a list of DNS servers from a file, or an error if one occurs -func GetDNSServers(path string) ([]string, error) { +// GetDNSServers returns a list of IPv4, IPv6 DNS servers from a file, or an error if one occurs +func GetDNSServers(path string) (ipv4, ipv6 []string, err error) { c, err := dns.ClientConfigFromFile(path) if err != nil { - return []string{}, fmt.Errorf("error reading DNS config file: %w", err) + return []string{}, []string{}, fmt.Errorf("error reading DNS config file (%s): %w", path, err) } servers := make([]string, 0, len(c.Servers)) for _, s := range c.Servers { @@ -41,7 +41,22 @@ func GetDNSServers(path string) ([]string, error) { full := strings.Join([]string{s, c.Port}, ":") servers = append(servers, full) } - return servers, nil + ipv4 = make([]string, 0, len(servers)) + ipv6 = make([]string, 0, len(servers)) + for _, s := range servers { + ip, _, err := util.SplitHostPort(s) + if err != nil { + return []string{}, []string{}, fmt.Errorf("could not parse IP address (%s) from file: %w", s, err) + } + if ip.To4() != nil { + ipv4 = append(ipv4, s) + } else if util.IsIPv6(&ip) { + ipv6 = append(ipv6, s) + } else { + return []string{}, []string{}, fmt.Errorf("could not parse IP address (%s) from file: %s", s, path) + } + } + return ipv4, ipv6, nil } // Lookup client interface for help in mocking @@ -56,6 +71,7 @@ func (lc LookupClient) DoSingleDstServerLookup(r *Resolver, q Question, nameServ } func (r *Resolver) doSingleDstServerLookup(q Question, nameServer string, isIterative bool) (*SingleQueryResult, Trace, Status, error) { + var err error // Check that nameserver isn't blacklisted nameServerIPString, _, err := net.SplitHostPort(nameServer) if err != nil { @@ -266,7 +282,7 @@ func (r *Resolver) LookupAllNameservers(q *Question, nameServer string) (*Combin var curServer string // Lookup both ipv4 and ipv6 addresses of nameservers. - nsResults, nsTrace, nsStatus, nsError := r.DoNSLookup(q.Name, nameServer, false) + nsResults, nsTrace, nsStatus, nsError := r.DoNSLookup(q.Name, nameServer, false, true, true) // Terminate early if nameserver lookup also failed if nsStatus != StatusNoError { @@ -287,7 +303,12 @@ func (r *Resolver) LookupAllNameservers(q *Question, nameServer string) (*Combin ips := util.Concat(nserver.IPv4Addresses, nserver.IPv6Addresses) for _, ip := range ips { curServer = net.JoinHostPort(ip, "53") - res, trace, status, _ := r.ExternalLookup(q, curServer) + res, trace, status, err := r.ExternalLookup(q, curServer) + if err != nil { + // log and move on + log.Errorf("lookup for domain %s to nameserver %s failed with error %s. Continueing to next nameserver", q.Name, curServer, err) + continue + } fullTrace = append(fullTrace, trace...) extendedResult := ExtendedResult{ @@ -406,14 +427,37 @@ func (r *Resolver) cachedRetryingLookup(ctx context.Context, q Question, nameSer // retryingLookup wraps around wireLookup to perform a DNS lookup with retries // Returns the result, status, number of tries, and error func (r *Resolver) retryingLookup(ctx context.Context, q Question, nameServer string, recursive bool) (SingleQueryResult, Status, int, error) { + // nameserver is required + if nameServer == "" { + return SingleQueryResult{}, StatusIllegalInput, 0, errors.New("no nameserver specified") + } + nameServerIP, _, err := util.SplitHostPort(nameServer) + if err != nil { + return SingleQueryResult{}, StatusError, 0, errors.Wrapf(err, "could not split nameserver %s to get IP", nameServer) + } + var connInfo *ConnectionInfo + if nameServerIP.To4() != nil { + connInfo = r.connInfoIPv4 + } else if nameServerIP.To16() != nil { + connInfo = r.connInfoIPv6 + } else { + return SingleQueryResult{}, StatusError, 0, fmt.Errorf("could not determine IP version of nameserver: %s", nameServer) + } + // check that our connection info is valid + if connInfo == nil { + return SingleQueryResult{}, StatusError, 0, fmt.Errorf("no connection info for nameserver: %s", nameServer) + } + // check loopback consistency + if nameServerIP.IsLoopback() != connInfo.localAddr.IsLoopback() { + return SingleQueryResult{}, StatusIllegalInput, 0, fmt.Errorf("nameserver %s must be reachable from the local address %s, ie. both must be loopback or not loopback", nameServer, connInfo.localAddr.String()) + } r.verboseLog(1, "****WIRE LOOKUP*** ", dns.TypeToString[q.Type], " ", q.Name, " ", nameServer) for i := 0; i <= r.retries; i++ { // check context before going into wireLookup if util.HasCtxExpired(&ctx) { - var r SingleQueryResult - return r, StatusTimeout, i + 1, nil + return SingleQueryResult{}, StatusTimeout, i + 1, nil } - result, status, err := wireLookup(ctx, r.udpClient, r.tcpClient, r.conn, q, nameServer, recursive, r.ednsOptions, r.dnsSecEnabled, r.checkingDisabledBit) + result, status, err := wireLookup(ctx, connInfo.udpClient, connInfo.tcpClient, connInfo.conn, q, nameServer, recursive, r.ednsOptions, r.dnsSecEnabled, r.checkingDisabledBit) if status != StatusTimeout || i == r.retries { return result, status, i + 1, err } @@ -528,7 +572,6 @@ func (r *Resolver) iterateOnAuthorities(ctx context.Context, q Question, depth i if nsStatus != StatusNoError { var err error newStatus, err := handleStatus(nsStatus, err) - // default case we continue if err == nil { if i+1 == len(result.Authorities) { r.verboseLog((depth + 2), "--> Auth find Failed. Unknown error. No more authorities to try, terminating: ", nsStatus) @@ -552,8 +595,11 @@ func (r *Resolver) iterateOnAuthorities(ctx context.Context, q Question, depth i } } iterateResult, newTrace, status, err := r.iterativeLookup(ctx, q, ns, depth+1, newLayer, newTrace) - if isStatusAnswer(status) { - r.verboseLog((depth + 1), "--> Auth Resolution success: ", status) + if status == StatusNoNeededGlue { + r.verboseLog((depth + 2), "--> Auth resolution of ", ns, " was unsuccessful. No glue to follow", status) + return iterateResult, newTrace, status, err + } else if isStatusAnswer(status) { + r.verboseLog((depth + 1), "--> Auth Resolution of ", ns, " success: ", status) return iterateResult, newTrace, status, err } else if i+1 < len(result.Authorities) { r.verboseLog((depth + 2), "--> Auth resolution of ", ns, " Failed: ", status, ". Will try next authority") @@ -585,16 +631,25 @@ func (r *Resolver) extractAuthority(ctx context.Context, authority interface{}, // Short circuit a lookup from the glue // Normally this would be handled by caching, but we want to support following glue // that would normally be cache poison. Because it's "ok" and quite common - res, status := checkGlue(server, *result) + res, status := checkGlue(server, *result, r.ipVersionMode, r.iterationIPPreference) if status != StatusNoError { + if ok, _ = nameIsBeneath(server, layer); ok { + // The domain we're searching for is beneath us but no glue was returned. We cannot proceed without this Glue. + // Terminating + return "", StatusNoNeededGlue, "", trace + } // Fall through to normal query var q Question q.Name = server - q.Type = dns.TypeA q.Class = dns.ClassINET + if r.ipVersionMode != IPv4Only && r.iterationIPPreference == PreferIPv6 { + q.Type = dns.TypeAAAA + } else { + q.Type = dns.TypeA + } res, trace, status, _ = r.iterativeLookup(ctx, q, r.randomRootNameServer(), depth+1, ".", trace) } - if status == StatusIterTimeout { + if status == StatusIterTimeout || status == StatusNoNeededGlue { return "", status, "", trace } if status == StatusNoError { @@ -604,9 +659,12 @@ func (r *Resolver) extractAuthority(ctx context.Context, authority interface{}, if !ok { continue } - if innerAns.Type == "A" { + if r.ipVersionMode != IPv6Only && innerAns.Type == "A" { server := strings.TrimSuffix(innerAns.Answer, ".") + ":53" return server, StatusNoError, layer, trace + } else if r.ipVersionMode != IPv4Only && innerAns.Type == "AAAA" { + server := "[" + strings.TrimSuffix(innerAns.Answer, ".") + "]:53" + return server, StatusNoError, layer, trace } } } diff --git a/src/zdns/lookup_test.go b/src/zdns/lookup_test.go index 07aa31a7..f6780b9f 100644 --- a/src/zdns/lookup_test.go +++ b/src/zdns/lookup_test.go @@ -58,9 +58,10 @@ func InitTest(t *testing.T) *ResolverConfig { mc := MockLookupClient{} config := NewResolverConfig() - config.ExternalNameServers = []string{"127.0.0.1:53"} - config.RootNameServers = []string{"127.0.0.1:53"} - config.LocalAddrs = []net.IP{net.ParseIP("127.0.0.1")} + config.ExternalNameServersV4 = []string{"127.0.0.1:53"} + config.RootNameServersV4 = []string{"127.0.0.1:53"} + config.LocalAddrsV4 = []net.IP{net.ParseIP("127.0.0.1")} + config.IPVersionMode = IPv4Only config.LookupClient = mc return config @@ -625,7 +626,7 @@ func TestOneA(t *testing.T) { require.NoError(t, err) domain1 := "example.com" - ns1 := config.ExternalNameServers[0] + ns1 := config.ExternalNameServersV4[0] domainNS1 := domainNS{domain: domain1, ns: ns1} mockResults[domainNS1] = SingleQueryResult{ @@ -641,7 +642,7 @@ func TestOneA(t *testing.T) { Protocol: "", Flags: DNSFlags{}, } - res, _, _, _ := resolver.DoTargetedLookup("example.com", ns1, IPv4Only, false) + res, _, _, _ := resolver.DoTargetedLookup("example.com", ns1, false, true, false) verifyResult(t, *res, []string{"192.0.2.1"}, nil) } @@ -653,7 +654,7 @@ func TestTwoA(t *testing.T) { require.NoError(t, err) domain1 := "example.com" - ns1 := config.ExternalNameServers[0] + ns1 := config.ExternalNameServersV4[0] domainNS1 := domainNS{domain: domain1, ns: ns1} mockResults[domainNS1] = SingleQueryResult{ @@ -676,7 +677,7 @@ func TestTwoA(t *testing.T) { Protocol: "", Flags: DNSFlags{}, } - res, _, _, _ := resolver.DoTargetedLookup(domain1, ns1, IPv4Only, false) + res, _, _, _ := resolver.DoTargetedLookup(domain1, ns1, false, true, false) verifyResult(t, *res, []string{"192.0.2.1", "192.0.2.2"}, nil) } @@ -688,7 +689,7 @@ func TestQuadAWithoutFlag(t *testing.T) { require.NoError(t, err) domain1 := "example.com" - ns1 := config.ExternalNameServers[0] + ns1 := config.ExternalNameServersV4[0] domainNS1 := domainNS{domain: domain1, ns: ns1} mockResults[domainNS1] = SingleQueryResult{ @@ -712,7 +713,7 @@ func TestQuadAWithoutFlag(t *testing.T) { Flags: DNSFlags{}, } - res, _, _, _ := resolver.DoTargetedLookup(domain1, ns1, IPv4Only, false) + res, _, _, _ := resolver.DoTargetedLookup(domain1, ns1, false, true, false) verifyResult(t, *res, []string{"192.0.2.1"}, nil) } @@ -724,7 +725,7 @@ func TestOnlyQuadA(t *testing.T) { require.NoError(t, err) domain1 := "example.com" - ns1 := config.ExternalNameServers[0] + ns1 := config.ExternalNameServersV4[0] domainNS1 := domainNS{domain: domain1, ns: ns1} mockResults[domainNS1] = SingleQueryResult{ @@ -741,7 +742,7 @@ func TestOnlyQuadA(t *testing.T) { Flags: DNSFlags{}, } - res, _, _, _ := resolver.DoTargetedLookup(domain1, ns1, IPv6Only, false) + res, _, _, _ := resolver.DoTargetedLookup(domain1, ns1, false, false, true) assert.NotNil(t, res) verifyResult(t, *res, nil, []string{"2001:db8::1"}) } @@ -754,7 +755,7 @@ func TestAandQuadA(t *testing.T) { require.NoError(t, err) domain1 := "example.com" - ns1 := config.ExternalNameServers[0] + ns1 := config.ExternalNameServersV4[0] domainNS1 := domainNS{domain: domain1, ns: ns1} mockResults[domainNS1] = SingleQueryResult{ @@ -777,7 +778,7 @@ func TestAandQuadA(t *testing.T) { Protocol: "", Flags: DNSFlags{}, } - res, _, _, _ := resolver.DoTargetedLookup(domain1, ns1, IPv4OrIPv6, false) + res, _, _, _ := resolver.DoTargetedLookup(domain1, ns1, false, true, true) assert.NotNil(t, res) verifyResult(t, *res, []string{"192.0.2.1"}, []string{"2001:db8::1"}) } @@ -790,7 +791,7 @@ func TestTwoQuadA(t *testing.T) { require.NoError(t, err) domain1 := "example.com" - ns1 := config.ExternalNameServers[0] + ns1 := config.ExternalNameServersV4[0] domainNS1 := domainNS{domain: domain1, ns: ns1} mockResults[domainNS1] = SingleQueryResult{ @@ -813,7 +814,7 @@ func TestTwoQuadA(t *testing.T) { Protocol: "", Flags: DNSFlags{}, } - res, _, _, _ := resolver.DoTargetedLookup("example.com", ns1, IPv6Only, false) + res, _, _, _ := resolver.DoTargetedLookup("example.com", ns1, false, false, true) assert.NotNil(t, res) verifyResult(t, *res, nil, []string{"2001:db8::1", "2001:db8::2"}) } @@ -827,7 +828,7 @@ func TestNoResults(t *testing.T) { require.NoError(t, err) domain1 := "example.com" - ns1 := config.ExternalNameServers[0] + ns1 := config.ExternalNameServersV4[0] domainNS1 := domainNS{domain: domain1, ns: ns1} mockResults[domainNS1] = SingleQueryResult{ @@ -837,7 +838,7 @@ func TestNoResults(t *testing.T) { Protocol: "", Flags: DNSFlags{}, } - res, _, _, _ := resolver.DoTargetedLookup("example.com", ns1, IPv4Only, false) + res, _, _, _ := resolver.DoTargetedLookup("example.com", ns1, false, true, false) verifyResult(t, *res, nil, nil) } @@ -849,7 +850,7 @@ func TestCname(t *testing.T) { require.NoError(t, err) domain1 := "cname.example.com" - ns1 := config.ExternalNameServers[0] + ns1 := config.ExternalNameServersV4[0] domainNS1 := domainNS{domain: domain1, ns: ns1} mockResults[domainNS1] = SingleQueryResult{ @@ -883,7 +884,7 @@ func TestCname(t *testing.T) { Protocol: "", Flags: DNSFlags{}, } - res, _, _, _ := resolver.DoTargetedLookup("cname.example.com", ns1, IPv4Only, false) + res, _, _, _ := resolver.DoTargetedLookup("cname.example.com", ns1, false, true, false) verifyResult(t, *res, []string{"192.0.2.1"}, nil) } @@ -895,7 +896,7 @@ func TestQuadAWithCname(t *testing.T) { require.NoError(t, err) domain1 := "cname.example.com" - ns1 := config.ExternalNameServers[0] + ns1 := config.ExternalNameServersV4[0] domainNS1 := domainNS{domain: domain1, ns: ns1} mockResults[domainNS1] = SingleQueryResult{ @@ -918,7 +919,7 @@ func TestQuadAWithCname(t *testing.T) { Protocol: "", Flags: DNSFlags{}, } - res, _, _, _ := resolver.DoTargetedLookup("cname.example.com", ns1, IPv6Only, false) + res, _, _, _ := resolver.DoTargetedLookup("cname.example.com", ns1, false, false, true) verifyResult(t, *res, nil, []string{"2001:db8::3"}) } @@ -930,7 +931,7 @@ func TestUnexpectedMxOnly(t *testing.T) { require.NoError(t, err) domain1 := "example.com" - ns1 := config.ExternalNameServers[0] + ns1 := config.ExternalNameServersV4[0] domainNS1 := domainNS{domain: domain1, ns: ns1} mockResults[domainNS1] = SingleQueryResult{ @@ -947,7 +948,7 @@ func TestUnexpectedMxOnly(t *testing.T) { Flags: DNSFlags{}, } - res, _, status, _ := resolver.DoTargetedLookup("example.com", ns1, IPv4OrIPv6, false) + res, _, status, _ := resolver.DoTargetedLookup("example.com", ns1, false, true, true) if status != StatusError { t.Errorf("Expected ERROR status, got %v", status) @@ -964,7 +965,7 @@ func TestMxAndAdditionals(t *testing.T) { require.NoError(t, err) domain1 := "example.com" - ns1 := config.ExternalNameServers[0] + ns1 := config.ExternalNameServersV4[0] domainNS1 := domainNS{domain: domain1, ns: ns1} mockResults[domainNS1] = SingleQueryResult{ @@ -994,7 +995,7 @@ func TestMxAndAdditionals(t *testing.T) { Flags: DNSFlags{}, } - res, _, _, _ := resolver.DoTargetedLookup("example.com", ns1, IPv4OrIPv6, false) + res, _, _, _ := resolver.DoTargetedLookup("example.com", ns1, false, true, true) verifyResult(t, *res, []string{"192.0.2.3"}, []string{"2001:db8::4"}) } @@ -1006,7 +1007,7 @@ func TestMismatchIpType(t *testing.T) { require.NoError(t, err) domain1 := "example.com" - ns1 := config.ExternalNameServers[0] + ns1 := config.ExternalNameServersV4[0] domainNS1 := domainNS{domain: domain1, ns: ns1} mockResults[domainNS1] = SingleQueryResult{ @@ -1023,7 +1024,7 @@ func TestMismatchIpType(t *testing.T) { Flags: DNSFlags{}, } - res, _, status, _ := resolver.DoTargetedLookup("example.com", ns1, IPv4OrIPv6, false) + res, _, status, _ := resolver.DoTargetedLookup("example.com", ns1, false, true, true) if status != StatusError { t.Errorf("Expected ERROR status, got %v", status) @@ -1040,7 +1041,7 @@ func TestCnameLoops(t *testing.T) { require.NoError(t, err) domain1 := "cname1.example.com" - ns1 := config.ExternalNameServers[0] + ns1 := config.ExternalNameServersV4[0] domainNS1 := domainNS{domain: domain1, ns: ns1} mockResults[domainNS1] = SingleQueryResult{ @@ -1075,7 +1076,7 @@ func TestCnameLoops(t *testing.T) { Flags: DNSFlags{}, } - res, _, status, _ := resolver.DoTargetedLookup("cname1.example.com", ns1, IPv4OrIPv6, false) + res, _, status, _ := resolver.DoTargetedLookup("cname1.example.com", ns1, false, true, true) if status != StatusError { t.Errorf("Expected ERROR status, got %v", status) @@ -1091,7 +1092,7 @@ func TestExtendedRecursion(t *testing.T) { resolver, err := InitResolver(config) require.NoError(t, err) - ns1 := config.ExternalNameServers[0] + ns1 := config.ExternalNameServersV4[0] // Create a CNAME chain of length > 10 for i := 1; i < 12; i++ { domainNSRecord := domainNS{ @@ -1113,7 +1114,7 @@ func TestExtendedRecursion(t *testing.T) { } } - res, _, status, _ := resolver.DoTargetedLookup("cname1.example.com", ns1, IPv4OrIPv6, false) + res, _, status, _ := resolver.DoTargetedLookup("cname1.example.com", ns1, false, true, true) if status != StatusError { t.Errorf("Expected ERROR status, got %v", status) @@ -1130,7 +1131,7 @@ func TestEmptyNonTerminal(t *testing.T) { require.NoError(t, err) domain1 := "leaf.intermediate.example.com" - ns1 := config.ExternalNameServers[0] + ns1 := config.ExternalNameServersV4[0] domainNS1 := domainNS{domain: domain1, ns: ns1} mockResults[domainNS1] = SingleQueryResult{ @@ -1159,11 +1160,11 @@ func TestEmptyNonTerminal(t *testing.T) { Flags: DNSFlags{}, } // Verify leaf returns correctly - res, _, _, _ := resolver.DoTargetedLookup("leaf.intermediate.example.com", ns1, IPv4Only, false) + res, _, _, _ := resolver.DoTargetedLookup("leaf.intermediate.example.com", ns1, false, true, false) verifyResult(t, *res, []string{"192.0.2.3"}, nil) // Verify empty non-terminal returns no answer - res, _, _, _ = resolver.DoTargetedLookup("intermediate.example.com", ns1, IPv4OrIPv6, false) + res, _, _, _ = resolver.DoTargetedLookup("intermediate.example.com", ns1, false, true, true) verifyResult(t, *res, nil, nil) } @@ -1173,8 +1174,8 @@ func TestNXDomain(t *testing.T) { config := InitTest(t) resolver, err := InitResolver(config) require.NoError(t, err) - ns1 := config.ExternalNameServers[0] - res, _, status, _ := resolver.DoTargetedLookup("nonexistent.example.com", ns1, IPv4OrIPv6, false) + ns1 := config.ExternalNameServersV4[0] + res, _, status, _ := resolver.DoTargetedLookup("nonexistent.example.com", ns1, false, true, true) if status != StatusNXDomain { t.Errorf("Expected StatusNXDomain status, got %v", status) } else if res != nil { @@ -1191,7 +1192,7 @@ func TestAandQuadADedup(t *testing.T) { domain1 := "cname1.example.com" domain2 := "cname2.example.com" domain3 := "example.com" - ns1 := config.ExternalNameServers[0] + ns1 := config.ExternalNameServersV4[0] domainNS1 := domainNS{domain: domain1, ns: ns1} domainNS2 := domainNS{domain: domain2, ns: ns1} domainNS3 := domainNS{domain: domain3, ns: ns1} @@ -1274,7 +1275,7 @@ func TestAandQuadADedup(t *testing.T) { Flags: DNSFlags{}, } - res, _, _, _ := resolver.DoTargetedLookup(domain1, ns1, IPv4OrIPv6, false) + res, _, _, _ := resolver.DoTargetedLookup(domain1, ns1, false, true, true) assert.NotNil(t, res) verifyResult(t, *res, []string{"192.0.2.1"}, []string{"2001:db8::3"}) } @@ -1287,14 +1288,14 @@ func TestServFail(t *testing.T) { require.NoError(t, err) domain1 := "example.com" - ns1 := config.ExternalNameServers[0] + ns1 := config.ExternalNameServersV4[0] domainNS1 := domainNS{domain: domain1, ns: ns1} mockResults[domainNS1] = SingleQueryResult{} name := "example.com" protocolStatus[domainNS1] = StatusServFail - res, _, finalStatus, _ := resolver.DoTargetedLookup(name, ns1, IPv4OrIPv6, false) + res, _, finalStatus, _ := resolver.DoTargetedLookup(name, ns1, false, true, true) if finalStatus != protocolStatus[domainNS1] { t.Errorf("Expected %v status, got %v", protocolStatus, finalStatus) @@ -1324,7 +1325,7 @@ func TestNsAInAdditional(t *testing.T) { require.NoError(t, err) domain1 := "example.com" - ns1 := config.ExternalNameServers[0] + ns1 := config.ExternalNameServersV4[0] domainNS1 := domainNS{domain: domain1, ns: ns1} mockResults[domainNS1] = SingleQueryResult{ @@ -1356,7 +1357,7 @@ func TestNsAInAdditional(t *testing.T) { IPv4Addresses: []string{"192.0.2.3"}, IPv6Addresses: nil, } - res, _, _, _ := resolver.DoNSLookup("example.com", ns1, false) + res, _, _, _ := resolver.DoNSLookup("example.com", ns1, false, true, false) verifyNsResult(t, res.Servers, expectedServersMap) } @@ -1367,7 +1368,7 @@ func TestTwoNSInAdditional(t *testing.T) { require.NoError(t, err) domain1 := "example.com" - ns1 := config.ExternalNameServers[0] + ns1 := config.ExternalNameServersV4[0] domainNS1 := domainNS{domain: domain1, ns: ns1} mockResults[domainNS1] = SingleQueryResult{ @@ -1417,18 +1418,18 @@ func TestTwoNSInAdditional(t *testing.T) { IPv4Addresses: []string{"192.0.2.4"}, IPv6Addresses: nil, } - res, _, _, _ := resolver.DoNSLookup("example.com", ns1, false) + res, _, _, _ := resolver.DoNSLookup("example.com", ns1, false, true, false) verifyNsResult(t, res.Servers, expectedServersMap) } func TestAandQuadAInAdditional(t *testing.T) { config := InitTest(t) - config.IPVersionMode = IPv4OrIPv6 + //config.IPVersionMode = IPv4OrIPv6 resolver, err := InitResolver(config) require.NoError(t, err) domain1 := "example.com" - ns1 := config.ExternalNameServers[0] + ns1 := config.ExternalNameServersV4[0] domainNS1 := domainNS{domain: domain1, ns: ns1} mockResults[domainNS1] = SingleQueryResult{ @@ -1467,18 +1468,18 @@ func TestAandQuadAInAdditional(t *testing.T) { IPv4Addresses: []string{"192.0.2.3"}, IPv6Addresses: []string{"2001:db8::4"}, } - res, _, _, _ := resolver.DoNSLookup("example.com", ns1, false) + res, _, _, _ := resolver.DoNSLookup("example.com", ns1, false, true, true) verifyNsResult(t, res.Servers, expectedServersMap) } func TestNsMismatchIpType(t *testing.T) { config := InitTest(t) - config.IPVersionMode = IPv4OrIPv6 + //config.IPVersionMode = IPv4OrIPv6 resolver, err := InitResolver(config) require.NoError(t, err) domain1 := "example.com" - ns1 := config.ExternalNameServers[0] + ns1 := config.ExternalNameServersV4[0] domainNS1 := domainNS{domain: domain1, ns: ns1} mockResults[domainNS1] = SingleQueryResult{ @@ -1517,18 +1518,18 @@ func TestNsMismatchIpType(t *testing.T) { IPv4Addresses: nil, IPv6Addresses: nil, } - res, _, _, _ := resolver.DoNSLookup("example.com", ns1, false) + res, _, _, _ := resolver.DoNSLookup("example.com", ns1, false, true, true) verifyNsResult(t, res.Servers, expectedServersMap) } func TestAandQuadALookup(t *testing.T) { config := InitTest(t) - config.IPVersionMode = IPv4OrIPv6 + //config.IPVersionMode = IPv4OrIPv6 resolver, err := InitResolver(config) require.NoError(t, err) domain1 := "example.com" - ns1 := config.ExternalNameServers[0] + ns1 := config.ExternalNameServersV4[0] domainNS1 := domainNS{domain: domain1, ns: ns1} mockResults[domainNS1] = SingleQueryResult{ @@ -1579,7 +1580,7 @@ func TestAandQuadALookup(t *testing.T) { IPv4Addresses: []string{"192.0.2.3"}, IPv6Addresses: []string{"2001:db8::4"}, } - res, _, _, _ := resolver.DoNSLookup("example.com", ns1, false) + res, _, _, _ := resolver.DoNSLookup("example.com", ns1, false, true, true) verifyNsResult(t, res.Servers, expectedServersMap) } @@ -1588,9 +1589,9 @@ func TestNsNXDomain(t *testing.T) { resolver, err := InitResolver(config) require.NoError(t, err) - ns1 := config.ExternalNameServers[0] + ns1 := config.ExternalNameServersV4[0] - _, _, status, _ := resolver.DoNSLookup("nonexistentexample.com", ns1, false) + _, _, status, _ := resolver.DoNSLookup("nonexistentexample.com", ns1, false, true, true) assert.Equal(t, StatusNXDomain, status) } @@ -1601,13 +1602,13 @@ func TestNsServFail(t *testing.T) { require.NoError(t, err) domain1 := "example.com" - ns1 := config.ExternalNameServers[0] + ns1 := config.ExternalNameServersV4[0] domainNS1 := domainNS{domain: domain1, ns: ns1} mockResults[domainNS1] = SingleQueryResult{} protocolStatus[domainNS1] = StatusServFail - res, _, status, _ := resolver.DoNSLookup("example.com", ns1, false) + res, _, status, _ := resolver.DoNSLookup("example.com", ns1, false, true, false) assert.Equal(t, status, protocolStatus[domainNS1]) assert.Empty(t, res.Servers) @@ -1619,7 +1620,7 @@ func TestErrorInTargetedLookup(t *testing.T) { require.NoError(t, err) domain1 := "example.com" - ns1 := config.ExternalNameServers[0] + ns1 := config.ExternalNameServersV4[0] domainNS1 := domainNS{domain: domain1, ns: ns1} mockResults[domainNS1] = SingleQueryResult{ @@ -1640,7 +1641,7 @@ func TestErrorInTargetedLookup(t *testing.T) { protocolStatus[domainNS1] = StatusError - res, _, status, _ := resolver.DoNSLookup("example.com", ns1, false) + res, _, status, _ := resolver.DoNSLookup("example.com", ns1, false, true, false) assert.Empty(t, len(res.Servers), 0) assert.Equal(t, status, protocolStatus[domainNS1]) } @@ -1648,16 +1649,14 @@ func TestErrorInTargetedLookup(t *testing.T) { // Test One NS with one IP with only ipv4-lookup func TestAllNsLookupOneNs(t *testing.T) { config := InitTest(t) - config.LocalAddrs = []net.IP{net.ParseIP(LoopbackAddrString)} - config.IPVersionMode = IPv4OrIPv6 + config.LocalAddrsV4 = []net.IP{net.ParseIP("127.0.0.1")} resolver, err := InitResolver(config) require.NoError(t, err) - ns1 := config.ExternalNameServers[0] + ns1 := config.ExternalNameServersV4[0] domain1 := "example.com" nsDomain1 := "ns1.example.com" ipv4_1 := "127.0.0.2" - ipv6_1 := "::1" domainNS1 := domainNS{domain: domain1, ns: ns1} mockResults[domainNS1] = SingleQueryResult{ @@ -1678,13 +1677,6 @@ func TestAllNsLookupOneNs(t *testing.T) { Name: nsDomain1 + ".", Answer: ipv4_1, }, - Answer{ - TTL: 3600, - Type: "AAAA", - Class: "IN", - Name: nsDomain1 + ".", - Answer: ipv6_1, - }, }, Authorities: nil, Protocol: "", @@ -1710,36 +1702,12 @@ func TestAllNsLookupOneNs(t *testing.T) { Flags: DNSFlags{}, } - ns3 := net.JoinHostPort(ipv6_1, "53") - domainNS3 := domainNS{domain: domain1, ns: ns3} - ipv4_3 := "127.0.0.4" - mockResults[domainNS3] = SingleQueryResult{ - Answers: []interface{}{ - Answer{ - TTL: 3600, - Type: "A", - Class: "IN", - Name: "example.com.", - Answer: ipv4_3, - }, - }, - Additional: nil, - Authorities: nil, - Protocol: "", - Flags: DNSFlags{}, - } - expectedRes := []ExtendedResult{ { Nameserver: nsDomain1, Status: StatusNoError, Res: mockResults[domainNS2], }, - { - Nameserver: nsDomain1, - Status: StatusNoError, - Res: mockResults[domainNS3], - }, } q := Question{ Type: dns.TypeNS, @@ -1760,7 +1728,7 @@ func TestAllNsLookupOneNsMultipleIps(t *testing.T) { resolver, err := InitResolver(config) require.NoError(t, err) - ns1 := config.ExternalNameServers[0] + ns1 := config.ExternalNameServersV4[0] domain1 := "example.com" nsDomain1 := "ns1.example.com" ipv4_1 := "127.0.0.2" @@ -1883,7 +1851,7 @@ func TestAllNsLookupTwoNs(t *testing.T) { resolver, err := InitResolver(config) require.NoError(t, err) - ns1 := config.ExternalNameServers[0] + ns1 := config.ExternalNameServersV4[0] domain1 := "example.com" nsDomain1 := "ns1.example.com" nsDomain2 := "ns2.example.com" @@ -1999,7 +1967,7 @@ func TestAllNsLookupErrorInOne(t *testing.T) { resolver, err := InitResolver(config) require.NoError(t, err) - ns1 := config.ExternalNameServers[0] + ns1 := config.ExternalNameServersV4[0] domain1 := "example.com" nsDomain1 := "ns1.example.com" ipv4_1 := "127.0.0.2" @@ -2099,7 +2067,7 @@ func TestAllNsLookupNXDomain(t *testing.T) { resolver, err := InitResolver(config) require.NoError(t, err) - ns1 := config.ExternalNameServers[0] + ns1 := config.ExternalNameServersV4[0] q := Question{ Type: dns.TypeNS, Class: dns.ClassINET, @@ -2119,7 +2087,7 @@ func TestAllNsLookupServFail(t *testing.T) { resolver, err := InitResolver(config) require.NoError(t, err) - ns1 := config.ExternalNameServers[0] + ns1 := config.ExternalNameServersV4[0] domain1 := "example.com" domainNS1 := domainNS{domain: domain1, ns: ns1} @@ -2140,8 +2108,8 @@ func TestAllNsLookupServFail(t *testing.T) { func TestInvalidInputsLookup(t *testing.T) { config := InitTest(t) - config.LocalAddrs = []net.IP{net.ParseIP("127.0.0.1")} - config.ExternalNameServers = []string{"127.0.0.1:53"} + config.LocalAddrsV4 = []net.IP{net.ParseIP("127.0.0.1")} + config.ExternalNameServersV4 = []string{"127.0.0.1:53"} resolver, err := InitResolver(config) require.NoError(t, err) q := Question{ diff --git a/src/zdns/nslookup.go b/src/zdns/nslookup.go index 3a306f97..a45c603e 100644 --- a/src/zdns/nslookup.go +++ b/src/zdns/nslookup.go @@ -39,10 +39,13 @@ type NSResult struct { } // DoNSLookup performs a DNS NS lookup on the given name against the given name server. -func (r *Resolver) DoNSLookup(lookupName, nameServer string, isIterative bool) (*NSResult, Trace, Status, error) { +func (r *Resolver) DoNSLookup(lookupName, nameServer string, isIterative, lookupA, lookupAAAA bool) (*NSResult, Trace, Status, error) { if len(lookupName) == 0 { return nil, nil, "", errors.New("no name provided for NS lookup") } + if !lookupA && !lookupAAAA { + return nil, nil, "", errors.New("must lookup either A or AAAA") + } var trace Trace var ns *SingleQueryResult @@ -93,17 +96,14 @@ func (r *Resolver) DoNSLookup(lookupName, nameServer string, isIterative bool) ( var findIpv4 = false var findIpv6 = false - lookupIPv4 := r.ipVersionMode == IPv4Only || r.ipVersionMode == IPv4OrIPv6 - lookupIPv6 := r.ipVersionMode == IPv6Only || r.ipVersionMode == IPv4OrIPv6 - - if lookupIPv4 { + if lookupA { if ips, ok := ipv4s[rec.Name]; ok { rec.IPv4Addresses = ips } else { findIpv4 = true } } - if lookupIPv6 { + if lookupAAAA { if ips, ok := ipv6s[rec.Name]; ok { rec.IPv6Addresses = ips } else { @@ -111,7 +111,7 @@ func (r *Resolver) DoNSLookup(lookupName, nameServer string, isIterative bool) ( } } if findIpv4 || findIpv6 { - res, nextTrace, _, _ := r.DoTargetedLookup(rec.Name, nameServer, r.ipVersionMode, false) + res, nextTrace, _, _ := r.DoTargetedLookup(rec.Name, nameServer, false, lookupA, lookupAAAA) if res != nil { if findIpv4 { rec.IPv4Addresses = res.IPv4Addresses diff --git a/src/zdns/resolver.go b/src/zdns/resolver.go index e19b6115..3af74be8 100644 --- a/src/zdns/resolver.go +++ b/src/zdns/resolver.go @@ -19,7 +19,6 @@ import ( "math/rand" "net" "strings" - "sync" "time" "github.com/pkg/errors" @@ -32,9 +31,6 @@ import ( ) const ( - // TODO - we'll need to update this when we add IPv6 support - LoopbackAddrString = "127.0.0.1" - defaultTimeout = 15 * time.Second // timeout for resolving a single name defaultIterativeTimeout = 4 * time.Second // timeout for single iteration in an iterative query defaultTransportMode = UDPOrTCP @@ -49,36 +45,40 @@ const ( defaultShouldTrace = false defaultDNSSECEnabled = false defaultIPVersionMode = IPv4Only + defaultIterationIPPreference = PreferIPv4 DefaultNameServerConfigFile = "/etc/resolv.conf" defaultLookupAllNameServers = false ) // ResolverConfig is a struct that holds all the configuration options for a Resolver. It is used to create a new Resolver. type ResolverConfig struct { - sync.Mutex Cache *Cache CacheSize int // don't use both cache and cacheSize LookupClient Lookuper // either a functional or mock Lookuper client for testing Blacklist *blacklist.SafeBlacklist - LocalAddrs []net.IP // local addresses to use for connections, one will be selected at random for the resolver + LocalAddrsV4 []net.IP // ipv4 local addresses to use for connections, one will be selected at random for the resolver + LocalAddrsV6 []net.IP // ipv6 local addresses to use for connections, one will be selected at random for the resolver Retries int LogLevel log.Level - TransportMode transportMode - IPVersionMode IPVersionMode - ShouldRecycleSockets bool - - IterativeTimeout time.Duration // applicable to iterative queries only, timeout for a single iteration step - Timeout time.Duration // timeout for the resolution of a single name - MaxDepth int - ExternalNameServers []string // name servers used for external lookups - RootNameServers []string // root servers used for iterative lookups - LookupAllNameServers bool // perform the lookup via all the nameservers for the domain - FollowCNAMEs bool // whether iterative lookups should follow CNAMEs/DNAMEs - DNSConfigFilePath string // path to the DNS config file, ex: /etc/resolv.conf + TransportMode transportMode + IPVersionMode IPVersionMode + IterationIPPreference IterationIPPreference // preference for IPv4 or IPv6 lookups in iterative queries + ShouldRecycleSockets bool + + IterativeTimeout time.Duration // applicable to iterative queries only, timeout for a single iteration step + Timeout time.Duration // timeout for the resolution of a single name + MaxDepth int + ExternalNameServersV4 []string // v4 name servers used for external lookups + ExternalNameServersV6 []string // v6 name servers used for external lookups + RootNameServersV4 []string // v4 root servers used for iterative lookups + RootNameServersV6 []string // v6 root servers used for iterative lookups + LookupAllNameServers bool // perform the lookup via all the nameservers for the domain + FollowCNAMEs bool // whether iterative lookups should follow CNAMEs/DNAMEs + DNSConfigFilePath string // path to the DNS config file, ex: /etc/resolv.conf DNSSecEnabled bool EdnsOptions []dns.EDNS0 @@ -99,64 +99,58 @@ func (rc *ResolverConfig) Validate() error { } // External Nameservers - if len(rc.ExternalNameServers) == 0 { - return errors.New("must have at least one external name server") + if rc.IPVersionMode != IPv6Only && len(rc.ExternalNameServersV4) == 0 { + // If IPv4 is supported, we require at least one IPv4 external nameserver + return errors.New("must have at least one external IPv4 name server if IPv4 mode is enabled") + } + if rc.IPVersionMode != IPv4Only && len(rc.ExternalNameServersV6) == 0 { + // If IPv6 is supported, we require at least one IPv6 external nameserver + return errors.New("must have at least one external IPv6 name server if IPv6 mode is enabled") } - for _, ns := range rc.ExternalNameServers { + // Validate all nameservers have ports and are valid IPs + for _, ns := range util.Concat(rc.ExternalNameServersV4, rc.ExternalNameServersV6) { ipString, _, err := net.SplitHostPort(ns) if err != nil { - return fmt.Errorf("could not parse external name server (%s), must be valid IP and have port appended, ex: 1.2.3.4:53", ns) + return fmt.Errorf("could not parse external name server (%s), must be valid IP and have port appended, ex: 1.2.3.4:53 or [::1]:53", ns) } ip := net.ParseIP(ipString) if ip == nil { - return fmt.Errorf("could not parse external name server (%s), must be valid IP and have port appended, ex: 1.2.3.4:53", ns) + return fmt.Errorf("could not parse external name server (%s), must be valid IP and have port appended, ex: 1.2.3.4:53 or [::1]:53", ns) } } - // Check Root Servers - if len(rc.RootNameServers) == 0 { - return errors.New("must have at least one root name server") + // Root Nameservers + if rc.IPVersionMode != IPv6Only && len(rc.RootNameServersV4) == 0 { + // If IPv4 is supported, we require at least one IPv4 root nameserver + return errors.New("must have at least one root IPv4 name server if IPv4 mode is enabled") } - for _, ns := range rc.RootNameServers { + if rc.IPVersionMode != IPv4Only && len(rc.RootNameServersV6) == 0 { + // If IPv6 is supported, we require at least one IPv6 root nameserver + return errors.New("must have at least one root IPv6 name server if IPv6 mode is enabled") + } + + // Validate all nameservers have ports and are valid IPs + for _, ns := range util.Concat(rc.RootNameServersV4, rc.RootNameServersV6) { ipString, _, err := net.SplitHostPort(ns) if err != nil { - return fmt.Errorf("could not parse root name server (%s), must be valid IP and have port appended, ex: 1.2.3.4:53", ns) + return fmt.Errorf("could not parse root name server (%s), must be valid IP and have port appended, ex: 1.2.3.4:53 or [::1]:53", ns) } ip := net.ParseIP(ipString) if ip == nil { - return fmt.Errorf("could not parse root name server (%s), must be valid IP and have port appended, ex: 1.2.3.4:53", ns) - } - } - - // TODO - Remove when we add IPv6 support - for _, ns := range rc.RootNameServers { - // we know ns passed validation above - ip, _, err := util.SplitHostPort(ns) - if err != nil { - return errors.Wrapf(err, "could not split host and port for root nameserver: %s", ns) - } - if util.IsIPv6(&ip) { - return fmt.Errorf("IPv6 root nameservers are not supported: %s", ns) + return fmt.Errorf("could not parse root name server (%s), must be valid IP and have port appended, ex: 1.2.3.4:53 or [::1]:53", ns) } } - for _, ns := range rc.ExternalNameServers { - // we know ns passed validation above - ip, _, err := util.SplitHostPort(ns) - if err != nil { - return errors.Wrapf(err, "could not split host and port for external nameserver: %s", ns) - } - if util.IsIPv6(&ip) { - return fmt.Errorf("IPv6 extenral nameservers are not supported: %s", ns) - } - } - // TODO end IPv6 section // Local Addresses - if len(rc.LocalAddrs) == 0 { - return errors.New("must have a local address to send traffic from") + if rc.IPVersionMode != IPv6Only && len(rc.LocalAddrsV4) == 0 { + return errors.New("must have a local IPv4 address to send traffic from") + } + if rc.IPVersionMode != IPv4Only && len(rc.LocalAddrsV6) == 0 { + return errors.New("must have a local IPv6 address to send traffic from") } - for _, ip := range rc.LocalAddrs { + // Validate all local addresses are valid IPs + for _, ip := range util.Concat(rc.LocalAddrsV4, rc.LocalAddrsV6) { if ip == nil { return errors.New("local address cannot be nil") } @@ -165,13 +159,37 @@ func (rc *ResolverConfig) Validate() error { } } - // TODO - Remove when we add IPv6 support - for _, addr := range rc.LocalAddrs { - if util.IsIPv6(&addr) { - return fmt.Errorf("IPv6 local addresses are not supported: %v", rc.LocalAddrs) + // Validate IPv4 local addresses are IPv4 + for _, ip := range rc.LocalAddrsV4 { + if ip.To4() == nil { + return fmt.Errorf("local address is not IPv4: %v", ip) + } + } + + // Validate IPv6 local addresses are IPv6 + for _, ip := range rc.LocalAddrsV6 { + if !util.IsIPv6(&ip) { + return fmt.Errorf("IPv6 local address (%v) is not IPv6", ip) + } + } + + // Ensure no IPv6 link-local/multicast local addresses are used + for _, ip := range rc.LocalAddrsV6 { + if ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() { + return fmt.Errorf("link-local IPv6 local addresses are not supported: %v", ip) + } + } + + // Ensure no IPv6 link-local/multicast external/root nameservers are used + for _, ns := range util.Concat(rc.ExternalNameServersV6, rc.RootNameServersV6) { + ip, _, err := util.SplitHostPort(ns) + if err != nil { + return errors.Wrapf(err, "could not split host and port for nameserver: %s", ns) + } + if ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() { + return fmt.Errorf("link-local IPv6 external/root nameservers are not supported: %v", ip) } } - // TODO end IPv6 section if err := rc.validateLoopbackConsistency(); err != nil { return errors.Wrap(err, "could not validate loopback consistency") @@ -183,20 +201,16 @@ func (rc *ResolverConfig) Validate() error { // validateLoopbackConsistency checks that the following is true // - either all nameservers AND all local addresses are loopback, or none are func (rc *ResolverConfig) validateLoopbackConsistency() error { - allIPsLength := len(rc.LocalAddrs) + len(rc.RootNameServers) + len(rc.ExternalNameServers) + allLocalAddrs := util.Concat(rc.LocalAddrsV4, rc.LocalAddrsV6) + allExternalNameServers := util.Concat(rc.ExternalNameServersV4, rc.ExternalNameServersV6) + allRootNameServers := util.Concat(rc.RootNameServersV4, rc.RootNameServersV6) + allIPsLength := len(allLocalAddrs) + len(allExternalNameServers) + len(allRootNameServers) allIPs := make([]net.IP, 0, allIPsLength) - allIPs = append(allIPs, rc.LocalAddrs...) - for _, ns := range rc.ExternalNameServers { + allIPs = append(allIPs, allLocalAddrs...) + for _, ns := range util.Concat(allExternalNameServers, allRootNameServers) { ip, _, err := util.SplitHostPort(ns) if err != nil { - return errors.Wrapf(err, "could not split host and port for external nameserver: %s", ns) - } - allIPs = append(allIPs, ip) - } - for _, ns := range rc.RootNameServers { - ip, _, err := util.SplitHostPort(ns) - if err != nil { - return errors.Wrapf(err, "could not split host and port for root nameserver: %s", ns) + return errors.Wrapf(err, "could not split host and port for nameserver: %s", ns) } allIPs = append(allIPs, ip) } @@ -210,15 +224,15 @@ func (rc *ResolverConfig) validateLoopbackConsistency() error { } } if allIPsLoopback == noneIPsLoopback { - return fmt.Errorf("cannot mix loopback and non-loopback local addresses (%v) and name servers (%v)", rc.LocalAddrs, util.Concat(rc.ExternalNameServers, rc.RootNameServers)) + return fmt.Errorf("cannot mix loopback and non-loopback local addresses (%v) and name servers (%v)", allLocalAddrs, util.Concat(allExternalNameServers, allRootNameServers)) } return nil } func (rc *ResolverConfig) PrintInfo() { - log.Infof("using local addresses: %v", rc.LocalAddrs) - log.Infof("for non-iterative lookups, using external nameservers: %s", strings.Join(rc.ExternalNameServers, ", ")) - log.Infof("for iterative lookups, using nameservers: %s", strings.Join(rc.RootNameServers, ", ")) + log.Infof("using local addresses: %v", util.Concat(rc.LocalAddrsV4, rc.LocalAddrsV6)) + log.Infof("for non-iterative lookups, using external nameservers: %s", strings.Join(util.Concat(rc.ExternalNameServersV4, rc.ExternalNameServersV6), ", ")) + log.Infof("for iterative lookups, using nameservers: %s", strings.Join(util.Concat(rc.RootNameServersV4, rc.RootNameServersV6), ", ")) } // NewResolverConfig creates a new ResolverConfig with default values. @@ -229,14 +243,16 @@ func NewResolverConfig() *ResolverConfig { LookupClient: LookupClient{}, Cache: c, - Blacklist: blacklist.New(), - LocalAddrs: nil, + Blacklist: blacklist.New(), + LocalAddrsV4: []net.IP{}, + LocalAddrsV6: []net.IP{}, - TransportMode: defaultTransportMode, - IPVersionMode: defaultIPVersionMode, - ShouldRecycleSockets: defaultShouldRecycleSockets, - LookupAllNameServers: false, - FollowCNAMEs: defaultFollowCNAMEs, + TransportMode: defaultTransportMode, + IPVersionMode: defaultIPVersionMode, + IterationIPPreference: defaultIterationIPPreference, + ShouldRecycleSockets: defaultShouldRecycleSockets, + LookupAllNameServers: false, + FollowCNAMEs: defaultFollowCNAMEs, Retries: defaultRetries, LogLevel: defaultLogVerbosity, @@ -250,6 +266,13 @@ func NewResolverConfig() *ResolverConfig { } } +type ConnectionInfo struct { + udpClient *dns.Client + tcpClient *dns.Client + conn *dns.Conn + localAddr net.IP +} + // Resolver is a struct that holds the state of a DNS resolver. It is used to perform DNS lookups. type Resolver struct { cache *Cache @@ -257,17 +280,16 @@ type Resolver struct { blacklist *blacklist.SafeBlacklist - udpClient *dns.Client - tcpClient *dns.Client - conn *dns.Conn - localAddr net.IP + connInfoIPv4 *ConnectionInfo + connInfoIPv6 *ConnectionInfo retries int logLevel log.Level - transportMode transportMode - ipVersionMode IPVersionMode - shouldRecycleSockets bool + transportMode transportMode + ipVersionMode IPVersionMode + iterationIPPreference IterationIPPreference + shouldRecycleSockets bool iterativeTimeout time.Duration timeout time.Duration // timeout for the network conns @@ -311,10 +333,11 @@ func InitResolver(config *ResolverConfig) (*Resolver, error) { logLevel: config.LogLevel, lookupAllNameServers: config.LookupAllNameServers, - transportMode: config.TransportMode, - ipVersionMode: config.IPVersionMode, - shouldRecycleSockets: config.ShouldRecycleSockets, - followCNAMEs: config.FollowCNAMEs, + transportMode: config.TransportMode, + ipVersionMode: config.IPVersionMode, + iterationIPPreference: config.IterationIPPreference, + shouldRecycleSockets: config.ShouldRecycleSockets, + followCNAMEs: config.FollowCNAMEs, timeout: config.Timeout, @@ -323,57 +346,95 @@ func InitResolver(config *ResolverConfig) (*Resolver, error) { checkingDisabledBit: config.CheckingDisabledBit, } log.SetLevel(r.logLevel) - r.localAddr = config.LocalAddrs[rand.Intn(len(config.LocalAddrs))] + if config.IPVersionMode != IPv6Only { + // create connection info for IPv4 + connInfo, err := getConnectionInfo(config.LocalAddrsV4, config.TransportMode, config.Timeout, config.ShouldRecycleSockets) + if err != nil { + return nil, fmt.Errorf("could not create connection info for IPv4: %w", err) + } + r.connInfoIPv4 = connInfo + } + if config.IPVersionMode != IPv4Only { + // create connection info for IPv6 + connInfo, err := getConnectionInfo(config.LocalAddrsV6, config.TransportMode, config.Timeout, config.ShouldRecycleSockets) + if err != nil { + return nil, fmt.Errorf("could not create connection info for IPv6: %w", err) + } + r.connInfoIPv6 = connInfo + } + // need to deep-copy here so we're not reliant on the state of the resolver config post-resolver creation + r.externalNameServers = make([]string, 0) + if config.IPVersionMode == IPv4Only || config.IPVersionMode == IPv4OrIPv6 { + ipv4Nameservers := make([]string, len(config.ExternalNameServersV4)) + // copy over IPv4 nameservers + elemsCopied := copy(ipv4Nameservers, config.ExternalNameServersV4) + if elemsCopied != len(config.ExternalNameServersV4) { + log.Fatal("failed to copy entire IPv4 name servers list from config") + } + r.externalNameServers = append(r.externalNameServers, ipv4Nameservers...) + } + ipv6Nameservers := make([]string, len(config.ExternalNameServersV6)) + if config.IPVersionMode == IPv6Only || config.IPVersionMode == IPv4OrIPv6 { + // copy over IPv6 nameservers + elemsCopied := copy(ipv6Nameservers, config.ExternalNameServersV6) + if elemsCopied != len(config.ExternalNameServersV6) { + log.Fatal("failed to copy entire IPv6 name servers list from config") + } + r.externalNameServers = append(r.externalNameServers, ipv6Nameservers...) + } + // deep copy external name servers from config to resolver + r.iterativeTimeout = config.IterativeTimeout + r.maxDepth = config.MaxDepth + r.rootNameServers = make([]string, 0, len(config.RootNameServersV4)+len(config.RootNameServersV6)) + if r.ipVersionMode != IPv6Only && len(config.RootNameServersV4) == 0 { + // add IPv4 root servers + r.rootNameServers = append(r.rootNameServers, RootServersV4...) + } else if r.ipVersionMode != IPv6Only { + r.rootNameServers = append(r.rootNameServers, config.RootNameServersV4...) + } + if r.ipVersionMode != IPv4Only && len(config.RootNameServersV6) == 0 { + // add IPv6 root servers + r.rootNameServers = append(r.rootNameServers, RootServersV6...) + } else if r.ipVersionMode != IPv4Only { + r.rootNameServers = append(r.rootNameServers, config.RootNameServersV6...) + } + return r, nil +} - if r.shouldRecycleSockets { +func getConnectionInfo(localAddr []net.IP, transportMode transportMode, timeout time.Duration, shouldRecycleSockets bool) (*ConnectionInfo, error) { + connInfo := &ConnectionInfo{ + localAddr: localAddr[rand.Intn(len(localAddr))], + } + if shouldRecycleSockets { // create persistent connection - conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: r.localAddr}) + conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: connInfo.localAddr}) if err != nil { return nil, fmt.Errorf("unable to create UDP connection: %w", err) } - r.conn = new(dns.Conn) - r.conn.Conn = conn + connInfo.conn = new(dns.Conn) + connInfo.conn.Conn = conn } - usingUDP := r.transportMode == UDPOrTCP || r.transportMode == UDPOnly + usingUDP := transportMode == UDPOrTCP || transportMode == UDPOnly if usingUDP { - r.udpClient = new(dns.Client) - r.udpClient.Timeout = r.timeout - r.udpClient.Dialer = &net.Dialer{ - Timeout: r.timeout, - LocalAddr: &net.UDPAddr{IP: r.localAddr}, + connInfo.udpClient = new(dns.Client) + connInfo.udpClient.Timeout = timeout + connInfo.udpClient.Dialer = &net.Dialer{ + Timeout: timeout, + LocalAddr: &net.UDPAddr{IP: connInfo.localAddr}, } } - usingTCP := r.transportMode == UDPOrTCP || r.transportMode == TCPOnly + usingTCP := transportMode == UDPOrTCP || transportMode == TCPOnly if usingTCP { - r.tcpClient = new(dns.Client) - r.tcpClient.Net = "tcp" - r.tcpClient.Timeout = r.timeout - r.tcpClient.Dialer = &net.Dialer{ - Timeout: config.Timeout, - LocalAddr: &net.TCPAddr{IP: r.localAddr}, - } - } - r.externalNameServers = make([]string, len(config.ExternalNameServers)) - // deep copy external name servers from config to resolver - elemsCopied := copy(r.externalNameServers, config.ExternalNameServers) - if elemsCopied != len(config.ExternalNameServers) { - log.Fatal("failed to copy entire name servers list from config") - } - r.iterativeTimeout = config.IterativeTimeout - r.maxDepth = config.MaxDepth - // use the set of 13 root name servers - if len(config.RootNameServers) == 0 { - r.rootNameServers = RootServersV4[:] - } else { - r.rootNameServers = make([]string, len(config.RootNameServers)) - // deep copy root name servers from config to resolver - elemsCopied = copy(r.rootNameServers, config.RootNameServers) - if elemsCopied != len(config.RootNameServers) { - log.Fatal("failed to copy entire root name servers list from config") + connInfo.tcpClient = new(dns.Client) + connInfo.tcpClient.Net = "tcp" + connInfo.tcpClient.Timeout = timeout + connInfo.tcpClient.Dialer = &net.Dialer{ + Timeout: timeout, + LocalAddr: &net.TCPAddr{IP: connInfo.localAddr}, } } - return r, nil + return connInfo, nil } // ExternalLookup performs a single lookup of a DNS question, q, against an external name server. @@ -394,23 +455,27 @@ func (r *Resolver) ExternalLookup(q *Question, dstServer string) (*SingleQueryRe } dstServerWithPort, err := util.AddDefaultPortToDNSServerName(dstServer) if err != nil { - // TODO update below when adding IPv6, add ex. IPv6 - return nil, nil, StatusIllegalInput, fmt.Errorf("could not parse name server (%s): %w. Correct format IPv4 (1.1.1.1:53)", dstServer, err) + return nil, nil, StatusIllegalInput, fmt.Errorf("could not parse name server (%s): %w. Correct format IPv4 1.1.1.1:53 or IPv6 [::1]:53", dstServer, err) } if dstServer != dstServerWithPort { log.Info("no port provided for external lookup, using default port 53") } dstServerIP, _, err := util.SplitHostPort(dstServerWithPort) if err != nil { - // TODO update below when adding IPv6, add ex. IPv6 - return nil, nil, StatusIllegalInput, fmt.Errorf("could not parse name server (%s): %w. Correct format IPv4 (1.1.1.1:53)", dstServer, err) + return nil, nil, StatusIllegalInput, fmt.Errorf("could not parse name server (%s): %w. Correct format IPv4 1.1.1.1:53 or IPv6 [::1]:53", dstServer, err) + } + if util.IsIPv6(&dstServerIP) && r.connInfoIPv6 == nil { + return nil, nil, StatusIllegalInput, fmt.Errorf("IPv6 external lookup requested for domain %s but no IPv6 local addresses provided to resolver", q.Name) + } else if dstServerIP.To4() != nil && r.connInfoIPv4 == nil { + return nil, nil, StatusIllegalInput, fmt.Errorf("IPv4 external lookup requested for domain %s but no IPv4 local addresses provided to resolver", q.Name) } // check that local address and dstServer's don't have a loopback mismatch - if r.localAddr.IsLoopback() != dstServerIP.IsLoopback() { + if dstServerIP.To4() != nil && r.connInfoIPv4.localAddr.IsLoopback() != dstServerIP.IsLoopback() { + return nil, nil, StatusIllegalInput, errors.New("cannot mix loopback and non-loopback addresses") + } else if util.IsIPv6(&dstServerIP) && r.connInfoIPv6.localAddr.IsLoopback() != dstServerIP.IsLoopback() { return nil, nil, StatusIllegalInput, errors.New("cannot mix loopback and non-loopback addresses") - } - // dstServer has been validated and has a port + // dstServer has been validated and has a port, continue with lookup dstServer = dstServerWithPort lookup, trace, status, err := r.lookupClient.DoSingleDstServerLookup(r, *q, dstServer, false) return lookup, trace, status, err @@ -432,9 +497,14 @@ func (r *Resolver) IterativeLookup(q *Question) (*SingleQueryResult, Trace, Stat // Close cleans up any resources used by the resolver. This should be called when the resolver is no longer needed. // Lookup will panic if called after Close. func (r *Resolver) Close() { - if r.conn != nil { - if err := r.conn.Close(); err != nil { - log.Errorf("error closing connection: %v", err) + if r.connInfoIPv4.conn != nil { + if err := r.connInfoIPv4.conn.Close(); err != nil { + log.Errorf("error closing IPv4 connection: %v", err) + } + } + if r.connInfoIPv6.conn != nil { + if err := r.connInfoIPv6.conn.Close(); err != nil { + log.Errorf("error closing IPv6 connection: %v", err) } } } diff --git a/src/zdns/resolver_test.go b/src/zdns/resolver_test.go index 0e412351..3435aa7e 100644 --- a/src/zdns/resolver_test.go +++ b/src/zdns/resolver_test.go @@ -24,51 +24,51 @@ import ( func TestResolverConfig_Validate(t *testing.T) { t.Run("Valid config with external/root name servers and local addr", func(t *testing.T) { rc := &ResolverConfig{ - ExternalNameServers: []string{"127.0.0.53:53"}, - RootNameServers: []string{"127.0.0.53:53"}, - LocalAddrs: []net.IP{net.ParseIP("127.0.0.1")}, + ExternalNameServersV4: []string{"127.0.0.53:53"}, + RootNameServersV4: []string{"127.0.0.53:53"}, + LocalAddrsV4: []net.IP{net.ParseIP("127.0.0.1")}, } err := rc.Validate() require.Nil(t, err, "Expected no error but got %v", err) }) t.Run("Using external nameserver with no port", func(t *testing.T) { rc := &ResolverConfig{ - ExternalNameServers: []string{"127.0.0.53"}, - RootNameServers: []string{"127.0.0.53:53"}, - LocalAddrs: []net.IP{net.ParseIP("127.0.0.1")}, + ExternalNameServersV4: []string{"127.0.0.53"}, + RootNameServersV4: []string{"127.0.0.53:53"}, + LocalAddrsV4: []net.IP{net.ParseIP("127.0.0.1")}, } err := rc.Validate() require.NotNil(t, err) }) t.Run("Using root nameserver with no port", func(t *testing.T) { rc := &ResolverConfig{ - ExternalNameServers: []string{"127.0.0.53:53"}, - RootNameServers: []string{"127.0.0.53"}, - LocalAddrs: []net.IP{net.ParseIP("127.0.0.1")}, + ExternalNameServersV4: []string{"127.0.0.53:53"}, + RootNameServersV4: []string{"127.0.0.53"}, + LocalAddrsV4: []net.IP{net.ParseIP("127.0.0.1")}, } err := rc.Validate() require.NotNil(t, err) }) t.Run("Missing external nameserver", func(t *testing.T) { rc := &ResolverConfig{ - RootNameServers: []string{"127.0.0.53:53"}, - LocalAddrs: []net.IP{net.ParseIP("127.0.0.1")}, + RootNameServersV4: []string{"127.0.0.53:53"}, + LocalAddrsV4: []net.IP{net.ParseIP("127.0.0.1")}, } err := rc.Validate() require.NotNil(t, err) }) t.Run("Missing root nameserver", func(t *testing.T) { rc := &ResolverConfig{ - ExternalNameServers: []string{"127.0.0.53:53"}, - LocalAddrs: []net.IP{net.ParseIP("127.0.0.1")}, + ExternalNameServersV4: []string{"127.0.0.53:53"}, + LocalAddrsV4: []net.IP{net.ParseIP("127.0.0.1")}, } err := rc.Validate() require.NotNil(t, err) }) t.Run("Missing local addr", func(t *testing.T) { rc := &ResolverConfig{ - ExternalNameServers: []string{"127.0.0.53:53"}, - RootNameServers: []string{"127.0.0.53:53"}, + ExternalNameServersV4: []string{"127.0.0.53:53"}, + RootNameServersV4: []string{"127.0.0.53:53"}, } err := rc.Validate() require.NotNil(t, err) @@ -76,36 +76,36 @@ func TestResolverConfig_Validate(t *testing.T) { t.Run("Cannot mix loopback addresses in nameservers", func(t *testing.T) { rc := &ResolverConfig{ - ExternalNameServers: []string{"127.0.0.53:53, 1.1.1.1:53"}, - RootNameServers: []string{"127.0.0.53:53"}, - LocalAddrs: []net.IP{net.ParseIP("127.0.0.1")}, + ExternalNameServersV4: []string{"127.0.0.53:53, 1.1.1.1:53"}, + RootNameServersV4: []string{"127.0.0.53:53"}, + LocalAddrsV4: []net.IP{net.ParseIP("127.0.0.1")}, } err := rc.Validate() require.NotNil(t, err) }) t.Run("Cannot mix loopback addresses among nameservers", func(t *testing.T) { rc := &ResolverConfig{ - ExternalNameServers: []string{"1.1.1.1:53"}, - RootNameServers: []string{"127.0.0.53:53"}, - LocalAddrs: []net.IP{net.ParseIP("127.0.0.1")}, + ExternalNameServersV4: []string{"1.1.1.1:53"}, + RootNameServersV4: []string{"127.0.0.53:53"}, + LocalAddrsV4: []net.IP{net.ParseIP("127.0.0.1")}, } err := rc.Validate() require.NotNil(t, err) }) t.Run("Cannot reach loopback NSes from non-loopback local address", func(t *testing.T) { rc := &ResolverConfig{ - ExternalNameServers: []string{"127.0.0.53:53"}, - RootNameServers: []string{"127.0.0.53:53"}, - LocalAddrs: []net.IP{net.ParseIP("192.168.1.2")}, + ExternalNameServersV4: []string{"127.0.0.53:53"}, + RootNameServersV4: []string{"127.0.0.53:53"}, + LocalAddrsV4: []net.IP{net.ParseIP("192.168.1.2")}, } err := rc.Validate() require.NotNil(t, err) }) t.Run("Cannot reach non-loopback NSes from loopback local address", func(t *testing.T) { rc := &ResolverConfig{ - ExternalNameServers: []string{"1.1.1.1:53"}, - RootNameServers: []string{"1.1.1.1:53"}, - LocalAddrs: []net.IP{net.ParseIP("127.0.0.1")}, + ExternalNameServersV4: []string{"1.1.1.1:53"}, + RootNameServersV4: []string{"1.1.1.1:53"}, + LocalAddrsV4: []net.IP{net.ParseIP("127.0.0.1")}, } err := rc.Validate() require.NotNil(t, err) diff --git a/src/zdns/types.go b/src/zdns/types.go index 6fdb06ff..34fdd7d6 100644 --- a/src/zdns/types.go +++ b/src/zdns/types.go @@ -68,3 +68,27 @@ func (ivm IPVersionMode) IsValid() (bool, string) { } return true, "" } + +type IterationIPPreference int + +const ( + PreferIPv4 IterationIPPreference = iota + PreferIPv6 +) + +func GetIterationIPPreference(preferIPv4, preferIPv6 bool) IterationIPPreference { + if preferIPv4 { + return PreferIPv4 + } else if preferIPv6 { + return PreferIPv6 + } + return PreferIPv4 +} + +func (iip IterationIPPreference) IsValid() (bool, string) { + isValid := iip >= 0 && iip <= 1 + if !isValid { + return false, fmt.Sprintf("invalid iteration ip preference: %d", iip) + } + return true, "" +} diff --git a/src/zdns/util.go b/src/zdns/util.go index 28b67e22..d5571c5d 100644 --- a/src/zdns/util.go +++ b/src/zdns/util.go @@ -19,6 +19,8 @@ import ( "net" "strings" + log "github.com/sirupsen/logrus" + "github.com/zmap/dns" ) @@ -55,23 +57,53 @@ func nameIsBeneath(name, layer string) (bool, string) { return false, "" } -func checkGlue(server string, result SingleQueryResult) (SingleQueryResult, Status) { +func checkGlue(server string, result SingleQueryResult, ipMode IPVersionMode, ipPreference IterationIPPreference) (SingleQueryResult, Status) { + var ansType string + if ipMode == IPv4Only { + ansType = "A" + } else if ipMode == IPv6Only { + ansType = "AAAA" + } else if ipPreference == PreferIPv4 { + // must be using either IPv4 or IPv6 + ansType = "A" + } else if ipPreference == PreferIPv6 { + // must be using either IPv4 or IPv6 + ansType = "AAAA" + } else { + log.Fatal("should never hit this case in check glue: ", ipMode, ipPreference) + } + res, status := checkGlueHelper(server, ansType, result) + if status == StatusNoError || ipMode != IPv4OrIPv6 { + // If we have a valid answer, or we're not looking for both A and AAAA records, return + return res, status + } + // If we're looking for both A and AAAA records, and we didn't find an answer, try the other type + if ansType == "A" { + ansType = "AAAA" + } else { + ansType = "A" + } + return checkGlueHelper(server, ansType, result) +} + +func checkGlueHelper(server, ansType string, result SingleQueryResult) (SingleQueryResult, Status) { for _, additional := range result.Additional { ans, ok := additional.(Answer) if !ok { continue } - if ans.Type == "A" && strings.TrimSuffix(ans.Name, ".") == server { + // sanitize case and trailing dot + // RFC 4343 - states DNS names are case-insensitive + if ans.Type == ansType && strings.EqualFold(strings.TrimSuffix(ans.Name, "."), server) { var retv SingleQueryResult retv.Authorities = make([]interface{}, 0) - retv.Answers = make([]interface{}, 0) + retv.Answers = make([]interface{}, 0, 1) retv.Additional = make([]interface{}, 0) retv.Answers = append(retv.Answers, ans) return retv, StatusNoError } } - var r SingleQueryResult - return r, StatusError + return SingleQueryResult{}, StatusError } func makeVerbosePrefix(depth int) string { @@ -120,11 +152,13 @@ func TranslateDNSErrorCode(err int) Status { } // handleStatus is a helper function to deal with a status and error. Error is only returned if the status is an -// Iterative Timeout +// Iterative Timeout or NoNeededGlueRecord func handleStatus(status Status, err error) (Status, error) { switch status { case StatusIterTimeout: return status, err + case StatusNoNeededGlue: + return status, err case StatusNXDomain: return status, nil case StatusServFail: diff --git a/testing/ipv6_tests.py b/testing/ipv6_tests.py new file mode 100644 index 00000000..3205a7b4 --- /dev/null +++ b/testing/ipv6_tests.py @@ -0,0 +1,109 @@ +#!/usr/bin/env python3 + +import socket +import subprocess +import json +import unittest + + +class Tests(unittest.TestCase): + maxDiff = None + ZDNS_EXECUTABLE = "./zdns" + + ROOT_A = {"1.2.3.4", "2.3.4.5", "3.4.5.6"} + + ROOT_A_ANSWERS = [{"type": "A", "class": "IN", "answer": x, + "name": "zdns-testing.com"} for x in ROOT_A] + + def run_zdns(self, flags, name, executable=ZDNS_EXECUTABLE): + flags = flags + " --threads=10" + c = f"echo '{name}' | {executable} {flags}" + o = subprocess.check_output(c, shell=True) + return c, json.loads(o.rstrip()) + + def assertSuccess(self, res, cmd): + self.assertEqual(res["status"], "NOERROR", cmd) + + def assertServFail(self, res, cmd): + self.assertEqual(res["status"], "SERVFAIL", cmd) + + def assertEqualAnswers(self, res, correct, cmd, key="answer"): + self.assertIn("answers", res["data"]) + for answer in res["data"]["answers"]: + del answer["ttl"] + a = sorted(res["data"]["answers"], key=lambda x: x[key]) + b = sorted(correct, key=lambda x: x[key]) + helptext = "%s\nExpected:\n%s\n\nActual:\n%s" % (cmd, + json.dumps(b, indent=4), json.dumps(a, indent=4)) + + def test_a_ipv6(self): + c = "A --name-servers=[2001:4860:4860::8888]:53" + name = "zdns-testing.com" + cmd, res = self.run_zdns(c, name) + self.assertSuccess(res, cmd) + self.assertEqualAnswers(res, self.ROOT_A_ANSWERS, cmd) + + def test_ipv6_unreachable(self): + c = "A --iterative --6" + name = "esrg.stanford.edu" + cmd, res = self.run_zdns(c, name) + # esrg.stanford.edu is hosted on NS's that do not have an IPv6 address. Therefore, the lookup won't get sufficient glue records to resolve the query. + self.assertEqual(res["status"], "NONEEDEDGLUE", cmd) + + def test_ipv6_external_lookup_unreachable_nameserver(self): + c = "A --6=true --4=false --name-servers=1.1.1.1" + name = "zdns-testing.com" + try: + cmd, res = self.run_zdns(c, name) + except Exception as e: + return True + self.fail("Should have thrown an exception, shouldn't be able to reach any IPv4 servers while in IPv6 mode") + + def test_ipv4_external_lookup_unreachable_nameserver(self): + c = "A --6=false --4=true --name-servers=2606:4700:4700::1111" + name = "zdns-testing.com" + try: + cmd, res = self.run_zdns(c, name) + except Exception as e: + return True + self.fail("Should have thrown an exception, shouldn't be able to reach any IPv6 servers while in IPv4 mode") + + def test_ipv6_external_lookup_loopback_nameserver(self): + c = "A --6=true --4=false --name-servers=[::1]:53" + name = "zdns-testing.com" + try: + cmd, res = self.run_zdns(c, name) + except Exception as e: + return True + self.fail("Should have thrown an exception, shouldn't be able to use a loopback address as a nameserver in IPv6") + + def test_ipv6_happy_path_external(self): + c = "A --6=true" + name = "zdns-testing.com" + cmd, res = self.run_zdns(c, name) + self.assertSuccess(res, cmd) + self.assertEqualAnswers(res, self.ROOT_A_ANSWERS, cmd) + + def test_ipv6_happy_path_iterative(self): + c = "A --6=true --iterative" + name = "zdns-testing.com" + cmd, res = self.run_zdns(c, name) + self.assertSuccess(res, cmd) + self.assertEqualAnswers(res, self.ROOT_A_ANSWERS, cmd) + + def test_ipv6_happy_path_no_ipv4_iterative(self): + c = "A --6=true --4=false --iterative" + name = "zdns-testing.com" + cmd, res = self.run_zdns(c, name) + self.assertSuccess(res, cmd) + self.assertEqualAnswers(res, self.ROOT_A_ANSWERS, cmd) + + +if __name__ == "__main__": + try: + # Attempt to create an IPv6 socket + socket.socket(socket.AF_INET6, socket.SOCK_STREAM) + except OSError: + print("Error: no IPv6 support on this machine, cannot test IPv6 functionality") + exit(1) + unittest.main()