diff --git a/traverse.go b/traverse.go index faf2e70..90073e2 100644 --- a/traverse.go +++ b/traverse.go @@ -96,9 +96,12 @@ func (r *Reader) NetworksWithin(network *net.IPNet, options ...NetworksOption) * pointer, bit := r.traverseTree(ip, 0, uint(prefixLength)) - if bit < prefixLength { - ip = ip.Mask(net.CIDRMask(bit, len(ip)*8)) - } + // We could skip this when bit >= prefixLength if we assume that the network + // passed in is in canonical form. However, given that this may not be the + // case, it is safest to always take the mask. If this is hot code at some + // point, we could eliminate the allocation of the net.IPMask by zeroing + // out the bits in ip directly. + ip = ip.Mask(net.CIDRMask(bit, len(ip)*8)) networks.nodes = []netNode{ { ip: ip, diff --git a/traverse_test.go b/traverse_test.go index 0248243..00edfce 100644 --- a/traverse_test.go +++ b/traverse_test.go @@ -3,6 +3,8 @@ package maxminddb import ( "fmt" "net" + "strconv" + "strings" "testing" "github.com/stretchr/testify/assert" @@ -71,6 +73,8 @@ var tests = []networkTest{ }, }, { + // This is intentionally in non-canonical form to test + // that we handle it correctly. Network: "1.1.1.1/30", Database: "ipv4", Expected: []string{ @@ -78,6 +82,13 @@ var tests = []networkTest{ "1.1.1.2/31", }, }, + { + Network: "1.1.1.2/31", + Database: "ipv4", + Expected: []string{ + "1.1.1.2/31", + }, + }, { Network: "1.1.1.1/32", Database: "ipv4", @@ -267,7 +278,21 @@ func TestNetworksWithin(t *testing.T) { reader, err := Open(fileName) require.NoError(t, err, "unexpected error while opening database: %v", err) - _, network, err := net.ParseCIDR(v.Network) + // We are purposely not using net.ParseCIDR so that we can pass in + // values that aren't in canonical form. + parts := strings.Split(v.Network, "/") + ip := net.ParseIP(parts[0]) + if v := ip.To4(); v != nil { + ip = v + } + prefixLength, err := strconv.Atoi(parts[1]) + require.NoError(t, err) + mask := net.CIDRMask(prefixLength, len(ip)*8) + network := &net.IPNet{ + IP: ip, + Mask: mask, + } + require.NoError(t, err) n := reader.NetworksWithin(network, v.Options...) var innerIPs []string