Skip to content

Commit

Permalink
Pass domains as NameServers and NameServer struct re-factor (#435)
Browse files Browse the repository at this point in the history
* examples pass and code compiles

* fixed up modules

* module tests

* module tests working

* forgot to commit modules

* unit tests passing

* handle empty nameserver correctly

* added new unit tests

* added two negative test cases

* added integration test for new domain nameserver feature

* lint

* fixed (hopefully) host IP capability detection

* fix for metadata integration test when running in parrallel

* make file names unique for test parallelism

* use zdns defaults for name-server-mode

* review
  • Loading branch information
phillip-stephens authored Sep 5, 2024
1 parent cf164e2 commit e11da0b
Show file tree
Hide file tree
Showing 27 changed files with 743 additions and 519 deletions.
4 changes: 2 additions & 2 deletions examples/multi_thread_lookup/multi_threaded.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,8 @@ func initializeResolver(cache *zdns.Cache) *zdns.Resolver {
log.Fatal("Error getting local IP: ", err)
}
resolverConfig.LocalAddrsV4 = []net.IP{localAddr}
resolverConfig.ExternalNameServersV4 = []string{"1.1.1.1:53"}
resolverConfig.RootNameServersV4 = []string{"198.41.0.4:53"}
resolverConfig.ExternalNameServersV4 = []zdns.NameServer{{IP: net.ParseIP("1.1.1.1"), Port: 53}}
resolverConfig.RootNameServersV4 = []zdns.NameServer{{IP: net.ParseIP("198.41.0.4"), Port: 53}}
resolverConfig.IPVersionMode = zdns.IPv4Only
// Set any desired options on the ResolverConfig object
resolverConfig.Cache = cache
Expand Down
6 changes: 3 additions & 3 deletions examples/single_lookup/simple.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ func main() {
dnsQuestion := &zdns.Question{Name: domain, Type: dns.TypeA, Class: dns.ClassINET}
resolver := initializeResolver()

result, _, status, err := resolver.ExternalLookup(dnsQuestion, "1.1.1.1:53")
result, _, status, err := resolver.ExternalLookup(dnsQuestion, &zdns.NameServer{IP: net.ParseIP("1.1.1.1"), Port: 53})
if err != nil {
log.Fatal("Error looking up domain: ", err)
}
Expand Down Expand Up @@ -67,8 +67,8 @@ func initializeResolver() *zdns.Resolver {
// Set any desired options on the ResolverConfig object
resolverConfig.LogLevel = log.InfoLevel
resolverConfig.LocalAddrsV4 = []net.IP{localAddr}
resolverConfig.ExternalNameServersV4 = []string{"1.1.1.1:53"}
resolverConfig.RootNameServersV4 = []string{"198.41.0.4:53"}
resolverConfig.ExternalNameServersV4 = []zdns.NameServer{{IP: net.ParseIP("1.1.1.1"), Port: 53}}
resolverConfig.RootNameServersV4 = []zdns.NameServer{{IP: net.ParseIP("198.41.0.4"), Port: 53}}
resolverConfig.IPVersionMode = zdns.IPv4Only
// Create a new Resolver object with the ResolverConfig object, it will retain all settings set on the ResolverConfig object
resolver, err := zdns.InitResolver(resolverConfig)
Expand Down
4 changes: 2 additions & 2 deletions src/cli/modules.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import (

type LookupModule interface {
CLIInit(gc *CLIConf, rc *zdns.ResolverConfig) error
Lookup(resolver *zdns.Resolver, lookupName, nameServer string) (interface{}, zdns.Trace, zdns.Status, error)
Lookup(resolver *zdns.Resolver, lookupName string, nameServer *zdns.NameServer) (interface{}, zdns.Trace, zdns.Status, error)
Help() string // needed to satisfy the ZCommander interface in ZFlags.
GetDescription() string // needed to add a command to the parser, printed to the user. Printed to the user when they run the help command for a given module
Validate(args []string) error // needed to satisfy the ZCommander interface in ZFlags
Expand Down Expand Up @@ -165,7 +165,7 @@ func (lm *BasicLookupModule) NewFlags() interface{} {
return lm
}

func (lm *BasicLookupModule) Lookup(resolver *zdns.Resolver, lookupName, nameServer string) (interface{}, zdns.Trace, zdns.Status, error) {
func (lm *BasicLookupModule) Lookup(resolver *zdns.Resolver, lookupName string, nameServer *zdns.NameServer) (interface{}, zdns.Trace, zdns.Status, error) {
if lm.LookupAllNameServers {
return resolver.LookupAllNameservers(&zdns.Question{Name: lookupName, Type: lm.DNSType, Class: lm.DNSClass}, nameServer)
}
Expand Down
285 changes: 184 additions & 101 deletions src/cli/worker_manager.go

Large diffs are not rendered by default.

165 changes: 165 additions & 0 deletions src/cli/worker_manager_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
/*
* ZDNS Copyright 2022 Regents of the University of Michigan
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy
* of the License at http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied. See the License for the specific language governing
* permissions and limitations under the License.
*/

package cli

import (
"fmt"
"testing"

"github.com/stretchr/testify/require"

"github.com/zmap/zdns/src/zdns"
)

func TestConvertNameServerStringToNameServer(t *testing.T) {
tests := []struct {
nameServerString string
expectedNameServer string
}{
{
"1.1.1.1",
"1.1.1.1:53",
}, {
"1.1.1.1:35",
"1.1.1.1:35",
}, {
"2606:4700:4700::1111",
"[2606:4700:4700::1111]:53",
}, {
"[2606:4700:4700::1111]:35",
"[2606:4700:4700::1111]:35",
},
}
for _, test := range tests {
nses, err := convertNameServerStringToNameServer(test.nameServerString, zdns.IPv4OrIPv6)
require.Nil(t, err)
require.Len(t, nses, 1)
if nses[0].String() != test.expectedNameServer {
t.Errorf("Expected %s, got %s", test.expectedNameServer, nses[0].String())
}
}
// need to convert these to use the .String method to test
t.Run("Domain Name as Name Server, both IPv4 and v6", func(t *testing.T) {
nses, err := convertNameServerStringToNameServer("one.one.one.one", zdns.IPv4OrIPv6)
require.Nil(t, err)
expectedNSes := []string{"1.1.1.1:53", "1.0.0.1:53", "[2606:4700:4700::1111]:53", "[2606:4700:4700::1001]:53"}
containsExpectedNameServerStrings(t, nses, expectedNSes)
})
t.Run("Domain Name as Name Server, just IPv4", func(t *testing.T) {
nses, err := convertNameServerStringToNameServer("one.one.one.one", zdns.IPv4Only)
require.Nil(t, err)
expectedNSes := []string{"1.1.1.1:53", "1.0.0.1:53"}
containsExpectedNameServerStrings(t, nses, expectedNSes)
})
t.Run("Domain Name as Name Server, just IPv6", func(t *testing.T) {
nses, err := convertNameServerStringToNameServer("one.one.one.one", zdns.IPv6Only)
require.Nil(t, err)
expectedNSes := []string{"[2606:4700:4700::1111]:53", "[2606:4700:4700::1001]:53"}
containsExpectedNameServerStrings(t, nses, expectedNSes)
})
t.Run("Domain Name as Name Server, port provided", func(t *testing.T) {
nses, err := convertNameServerStringToNameServer("one.one.one.one:2345", zdns.IPv4OrIPv6)
require.Nil(t, err)
expectedNSes := []string{"1.1.1.1:2345", "1.0.0.1:2345", "[2606:4700:4700::1111]:2345", "[2606:4700:4700::1001]:2345"}
containsExpectedNameServerStrings(t, nses, expectedNSes)
})
t.Run("Bad domain name", func(t *testing.T) {
_, err := convertNameServerStringToNameServer("bad.domain.name", zdns.IPv4OrIPv6)
require.Error(t, err)
})
t.Run("Bad IP address", func(t *testing.T) {
_, err := convertNameServerStringToNameServer("1.1.1.556", zdns.IPv4OrIPv6)
require.Error(t, err)
})
}

func containsExpectedNameServerStrings(t *testing.T, actualNSes []zdns.NameServer, expectedNameServers []string) {
require.Len(t, actualNSes, len(expectedNameServers))
currentNS := ""
var foundNS bool
for _, ns := range expectedNameServers {
currentNS = ns
foundNS = false
for _, actualNS := range actualNSes {
if actualNS.String() == ns {
foundNS = true
break
}
}
if !foundNS {
require.Fail(t, fmt.Sprintf("Expected nameserver %s not present in actual list", currentNS))
}
}
}

func TestRemoveDomainsFromNameServersString(t *testing.T) {
tests := []struct {
input string
expected []string
}{
// Test with no name servers (empty list)
{
input: "",
expected: []string{},
},
// Test with single IP only
{
input: "1.1.1.1",
expected: []string{"1.1.1.1"},
},
// Test with single domain only
{
input: "example.com",
expected: []string{},
},
// Test with single IP+Port
{
input: "1.1.1.1:53",
expected: []string{"1.1.1.1:53"},
},
// Test with two IPs
{
input: "1.1.1.1,8.8.8.8",
expected: []string{"1.1.1.1", "8.8.8.8"},
},
// Test with IP and domain
{
input: "1.1.1.1,example.com",
expected: []string{"1.1.1.1"},
},
// Test with IP, IP+Port, and domain
{
input: "1.1.1.1,example.com,8.8.8.8:53",
expected: []string{"1.1.1.1", "8.8.8.8:53"},
},
// Test with IPv6, domain, and IPv4
{
input: "2001:4860:4860::8888,example.com,8.8.8.8",
expected: []string{"2001:4860:4860::8888", "8.8.8.8"},
},
// Test with IPv6+Port and domain
{
input: "[2001:4860:4860::8888]:53,example.com",
expected: []string{"[2001:4860:4860::8888]:53"},
},
}

for _, test := range tests {
t.Run(test.input, func(t *testing.T) {
result := removeDomainsFromNameServersString(test.input)
require.Equal(t, result, test.expected)
})
}
}
37 changes: 0 additions & 37 deletions src/internal/util/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ package util

import (
"context"
"fmt"
"net"
"regexp"
"strconv"
Expand Down Expand Up @@ -70,30 +69,6 @@ func SplitHostPort(inaddr string) (net.IP, int, error) {
return ip, portInt, nil
}

// SplitIPv4AndIPv6Addrs splits a list of IP addresses (either with port attached or not) into IPv4 and IPv6 addresses.
// Returns a slice of IPv4/IPv6 addresses that are guaranteed to be valid. If the port was attached, it'll be included.
func SplitIPv4AndIPv6Addrs(addrs []string) (ipv4 []string, ipv6 []string, err error) {
for _, addr := range addrs {
ip, _, err := SplitHostPort(addr)
if err != nil {
// addr may be an IP without a port
ip = net.ParseIP(addr)
}
if ip == nil {
return nil, nil, fmt.Errorf("invalid IP address: %s", addr)
}
// ip is valid, check if it's IPv4 or IPv6
if ip.To4() != nil {
ipv4 = append(ipv4, addr)
} else if ip.To16() != nil {
ipv6 = append(ipv6, addr)
} else {
return nil, nil, fmt.Errorf("invalid IP address: %s", addr)
}
}
return ipv4, ipv6, nil
}

// IsStringValidDomainName checks if the given string is a valid domain name using regex
func IsStringValidDomainName(domain string) bool {
var domainRegex = regexp.MustCompile(`^(?i)[a-z0-9]([a-z0-9-]{0,61}[a-z0-9])?(\.[a-z0-9]([a-z0-9-]{0,61}[a-z0-9])?)*\.[a-z]{2,}$`)
Expand Down Expand Up @@ -122,18 +97,6 @@ func Contains[T comparable](slice []T, entity T) bool {
return false
}

func RemoveDuplicates[T comparable](slice []T) []T {
lookup := make(map[T]struct{}, len(slice)) // prealloc for performance
result := make([]T, 0, len(slice))
for _, v := range slice {
if _, ok := lookup[v]; !ok {
lookup[v] = struct{}{}
result = append(result, v)
}
}
return result
}

// Concat returns a new slice concatenating the passed in slices.
//
// Avoids a gotcha in Go where since append modifies the underlying memory of the input slice, doing
Expand Down
36 changes: 0 additions & 36 deletions src/internal/util/util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,42 +108,6 @@ func TestContains(t *testing.T) {
}
}

func TestRemoveDuplicates(t *testing.T) {
tests := []struct {
name string
input []int
expected []int
}{
{
name: "No duplicates",
input: []int{1, 2, 3, 4, 5},
expected: []int{1, 2, 3, 4, 5},
},
{
name: "All duplicates",
input: []int{1, 1, 1, 1, 1},
expected: []int{1},
},
{
name: "Some duplicates",
input: []int{1, 2, 2, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 5},
expected: []int{1, 2, 3, 4, 5},
},
{
name: "Empty slice",
input: []int{},
expected: []int{},
},
}

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
result := RemoveDuplicates(test.input)
require.Equal(t, test.expected, result)
})
}
}

func TestConcat(t *testing.T) {
inputSlice1 := make([]int, 0, 10)
for i := 0; i < 3; i++ {
Expand Down
2 changes: 1 addition & 1 deletion src/modules/alookup/a_lookup.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ func (aMod *ALookupModule) Init(ipv4Lookup bool, ipv6Lookup bool) {
aMod.IPv6Lookup = ipv6Lookup
}

func (aMod *ALookupModule) Lookup(r *zdns.Resolver, lookupName, nameServer string) (interface{}, zdns.Trace, zdns.Status, error) {
func (aMod *ALookupModule) Lookup(r *zdns.Resolver, lookupName string, nameServer *zdns.NameServer) (interface{}, zdns.Trace, zdns.Status, error) {
ipResult, trace, status, err := r.DoTargetedLookup(lookupName, nameServer, aMod.baseModule.IsIterative, aMod.IPv4Lookup, aMod.IPv6Lookup)
return ipResult, trace, status, err
}
Expand Down
15 changes: 8 additions & 7 deletions src/modules/axfr/axfr.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,12 @@ type TransferClient struct {
dns.Transfer
}

func (axfrMod *AxfrLookupModule) doAXFR(name, server string) AXFRServerResult {
func (axfrMod *AxfrLookupModule) doAXFR(name string, server *zdns.NameServer) AXFRServerResult {
var retv AXFRServerResult
retv.Server = server
retv.Server = server.IP.String()
// check if the server address is blacklisted and if so, exclude
if axfrMod.Blacklist != nil {
if blacklisted, err := axfrMod.Blacklist.IsBlacklisted(server); err != nil {
if blacklisted, err := axfrMod.Blacklist.IsBlacklisted(server.IP.String()); err != nil {
retv.Status = zdns.StatusError
retv.Error = "blacklist-error"
return retv
Expand All @@ -79,7 +79,7 @@ func (axfrMod *AxfrLookupModule) doAXFR(name, server string) AXFRServerResult {
}
m := new(dns.Msg)
m.SetAxfr(dotName(name))
if a, err := axfrMod.In(m, net.JoinHostPort(server, "53")); err != nil {
if a, err := axfrMod.In(m, net.JoinHostPort(server.IP.String(), "53")); err != nil {
retv.Status = zdns.StatusError
retv.Error = err.Error()
return retv
Expand All @@ -101,9 +101,9 @@ func (axfrMod *AxfrLookupModule) doAXFR(name, server string) AXFRServerResult {
return retv
}

func (axfrMod *AxfrLookupModule) Lookup(resolver *zdns.Resolver, name, nameServer string) (interface{}, zdns.Trace, zdns.Status, error) {
func (axfrMod *AxfrLookupModule) Lookup(resolver *zdns.Resolver, name string, nameServer *zdns.NameServer) (interface{}, zdns.Trace, zdns.Status, error) {
var retv AXFRResult
if nameServer == "" {
if nameServer == nil {
parsedNS, trace, status, err := axfrMod.NSModule.Lookup(resolver, name, nameServer)
if status != zdns.StatusNoError {
return nil, trace, status, err
Expand All @@ -114,7 +114,8 @@ func (axfrMod *AxfrLookupModule) Lookup(resolver *zdns.Resolver, name, nameServe
}
for _, server := range castedNS.Servers {
if len(server.IPv4Addresses) > 0 {
retv.Servers = append(retv.Servers, axfrMod.doAXFR(name, server.IPv4Addresses[0]))
ns := &zdns.NameServer{IP: net.ParseIP(server.IPv4Addresses[0])}
retv.Servers = append(retv.Servers, axfrMod.doAXFR(name, ns))
}
}
} else {
Expand Down
Loading

0 comments on commit e11da0b

Please sign in to comment.