diff --git a/example_test.go b/example_test.go index f6768d1..4505987 100644 --- a/example_test.go +++ b/example_test.go @@ -155,3 +155,46 @@ func ExampleReader_Networks() { // 2003::/24: Cable/DSL } + +// This example demonstrates how to iterate over all networks in the +// database which are contained within an arbitrary network. +func ExampleReader_NetworksWithin() { + db, err := maxminddb.Open("test-data/test-data/GeoIP2-Connection-Type-Test.mmdb") + if err != nil { + log.Fatal(err) + } + defer db.Close() + + record := struct { + Domain string `maxminddb:"connection_type"` + }{} + + _, network, err := net.ParseCIDR("1.0.0.0/8") + if err != nil { + log.Fatal(err) + } + + networks := db.NetworksWithin(network) + for networks.Next() { + subnet, err := networks.Network(&record) + if err != nil { + log.Fatal(err) + } + fmt.Printf("%s: %s\n", subnet.String(), record.Domain) + } + if networks.Err() != nil { + log.Fatal(networks.Err()) + } + + // Output: + //1.0.0.0/24: Dialup + //1.0.1.0/24: Cable/DSL + //1.0.2.0/23: Dialup + //1.0.4.0/22: Dialup + //1.0.8.0/21: Dialup + //1.0.16.0/20: Dialup + //1.0.32.0/19: Dialup + //1.0.64.0/18: Dialup + //1.0.128.0/17: Dialup + +} diff --git a/reader.go b/reader.go index 0b58e36..7aee19e 100644 --- a/reader.go +++ b/reader.go @@ -249,7 +249,20 @@ func (r *Reader) lookupPointer(ip net.IP) (uint, int, net.IP, error) { if bitCount == 32 { node = r.ipv4Start } + node, prefixLength := r.traverseTree(ip, node, bitCount) + nodeCount := r.Metadata.NodeCount + if node == nodeCount { + // Record is empty + return 0, prefixLength, ip, nil + } else if node > nodeCount { + return node, prefixLength, ip, nil + } + + return 0, prefixLength, ip, newInvalidDatabaseError("invalid node in search tree") +} + +func (r *Reader) traverseTree(ip net.IP, node uint, bitCount uint) (uint, int) { nodeCount := r.Metadata.NodeCount i := uint(0) @@ -263,14 +276,8 @@ func (r *Reader) lookupPointer(ip net.IP) (uint, int, net.IP, error) { node = r.nodeReader.readRight(offset) } } - if node == nodeCount { - // Record is empty - return 0, int(i), ip, nil - } else if node > nodeCount { - return node, int(i), ip, nil - } - return 0, int(i), ip, newInvalidDatabaseError("invalid node in search tree") + return node, int(i) } func (r *Reader) retrieveData(pointer uint, result interface{}) error { diff --git a/traverse.go b/traverse.go index cbc78db..fff8466 100644 --- a/traverse.go +++ b/traverse.go @@ -1,6 +1,8 @@ package maxminddb -import "net" +import ( + "net" +) // Internal structure used to keep track of nodes we still need to visit. type netNode struct { @@ -17,6 +19,9 @@ type Networks struct { err error } +var allIPv4 = &net.IPNet{IP: make(net.IP, 4), Mask: net.CIDRMask(0, 32)} +var allIPv6 = &net.IPNet{IP: make(net.IP, 16), Mask: net.CIDRMask(0, 128)} + // Networks returns an iterator that can be used to traverse all networks in // the database. // @@ -24,15 +29,42 @@ type Networks struct { // in an IPv6 database. This iterator will iterate over all of these // locations separately. func (r *Reader) Networks() *Networks { - s := 4 + var networks *Networks if r.Metadata.IPVersion == 6 { - s = 16 + networks = r.NetworksWithin(allIPv6) + } else { + networks = r.NetworksWithin(allIPv4) + } + + return networks +} + +// NetworksWithin returns an iterator that can be used to traverse all networks +// in the database which are contained in a given network. +// +// Please note that a MaxMind DB may map IPv4 networks into several locations +// in an IPv6 database. This iterator will iterate over all of these locations +// separately. +// +// If the provided network is contained within a network in the database, the +// iterator will iterate over exactly one network, the containing network. +func (r *Reader) NetworksWithin(network *net.IPNet) *Networks { + ip := network.IP + prefixLength, _ := network.Mask.Size() + + if r.Metadata.IPVersion == 6 && len(ip) == net.IPv4len { + ip = net.IP.To16(ip) + prefixLength += 96 } + + pointer, bit := r.traverseTree(ip, 0, uint(prefixLength)) return &Networks{ reader: r, nodes: []netNode{ { - ip: make(net.IP, s), + ip: ip, + bit: uint(bit), + pointer: pointer, }, }, } diff --git a/traverse_test.go b/traverse_test.go index 41b9847..7ba1af1 100644 --- a/traverse_test.go +++ b/traverse_test.go @@ -2,6 +2,7 @@ package maxminddb import ( "fmt" + "net" "testing" "github.com/stretchr/testify/assert" @@ -46,3 +47,192 @@ func TestNetworksWithInvalidSearchTree(t *testing.T) { assert.NotNil(t, n.Err(), "no error received when traversing an broken search tree") assert.Equal(t, n.Err().Error(), "invalid search tree at 128.128.128.128/32") } + +type networkTest struct { + Network string + Database string + Expected []string +} + +var tests = []networkTest{ + networkTest{ + Network: "0.0.0.0/0", + Database: "ipv4", + Expected: []string{ + "1.1.1.1/32", + "1.1.1.2/31", + "1.1.1.4/30", + "1.1.1.8/29", + "1.1.1.16/28", + "1.1.1.32/32", + }, + }, + networkTest{ + Network: "1.1.1.1/30", + Database: "ipv4", + Expected: []string{ + "1.1.1.1/32", + "1.1.1.2/31", + }, + }, + networkTest{ + Network: "1.1.1.1/32", + Database: "ipv4", + Expected: []string{ + "1.1.1.1/32", + }, + }, + networkTest{ + Network: "255.255.255.0/24", + Database: "ipv4", + Expected: []string(nil), + }, + networkTest{ + Network: "1.1.1.1/32", + Database: "mixed", + Expected: []string{ + "1.1.1.1/32", + }, + }, + networkTest{ + Network: "255.255.255.0/24", + Database: "mixed", + Expected: []string(nil), + }, + networkTest{ + Network: "::1:ffff:ffff/128", + Database: "ipv6", + Expected: []string{ + "::1:ffff:ffff/128", + }, + }, + networkTest{ + Network: "::/0", + Database: "ipv6", + Expected: []string{ + "::1:ffff:ffff/128", + "::2:0:0/122", + "::2:0:40/124", + "::2:0:50/125", + "::2:0:58/127", + }, + }, + networkTest{ + Network: "::2:0:40/123", + Database: "ipv6", + Expected: []string{ + "::2:0:40/124", + "::2:0:50/125", + "::2:0:58/127", + }, + }, + networkTest{ + Network: "0:0:0:0:0:ffff:ffff:ff00/120", + Database: "ipv6", + Expected: []string(nil), + }, + networkTest{ + Network: "0.0.0.0/0", + Database: "mixed", + Expected: []string{ + "1.1.1.1/32", + "1.1.1.2/31", + "1.1.1.4/30", + "1.1.1.8/29", + "1.1.1.16/28", + "1.1.1.32/32", + }, + }, + networkTest{ + Network: "1.1.1.16/28", + Database: "mixed", + Expected: []string{ + "1.1.1.16/28", + }, + }, + networkTest{ + Network: "::/0", + Database: "ipv4", + Expected: []string{ + "101:101::/32", + "101:102::/31", + "101:104::/30", + "101:108::/29", + "101:110::/28", + "101:120::/32", + }, + }, + networkTest{ + Network: "101:104::/30", + Database: "ipv4", + Expected: []string{ + "101:104::/30", + }, + }, +} + +func TestNetworksWithin(t *testing.T) { + for _, v := range tests { + for _, recordSize := range []uint{24, 28, 32} { + fileName := testFile(fmt.Sprintf("MaxMind-DB-test-%s-%d.mmdb", v.Database, recordSize)) + reader, err := Open(fileName) + require.Nil(t, err, "unexpected error while opening database: %v", err) + defer reader.Close() + + _, network, err := net.ParseCIDR(v.Network) + assert.Nil(t, err) + n := reader.NetworksWithin(network) + var innerIPs []string + + for n.Next() { + record := struct { + IP string `maxminddb:"ip"` + }{} + network, err := n.Network(&record) + assert.Nil(t, err) + innerIPs = append(innerIPs, network.String()) + } + + assert.Equal(t, v.Expected, innerIPs) + assert.Nil(t, n.Err()) + } + } +} + +var geoIPTests = []networkTest{ + networkTest{ + Network: "81.2.69.128/26", + Database: "GeoIP2-Country-Test.mmdb", + Expected: []string{ + "81.2.69.142/31", + "81.2.69.144/28", + "81.2.69.160/27", + }, + }, +} + +func TestGeoIPNetworksWithin(t *testing.T) { + for _, v := range geoIPTests { + fileName := testFile(v.Database) + reader, err := Open(fileName) + require.Nil(t, err, "unexpected error while opening database: %v", err) + defer reader.Close() + + _, network, err := net.ParseCIDR(v.Network) + assert.Nil(t, err) + n := reader.NetworksWithin(network) + var innerIPs []string + + for n.Next() { + record := struct { + IP string `maxminddb:"ip"` + }{} + network, err := n.Network(&record) + assert.Nil(t, err) + innerIPs = append(innerIPs, network.String()) + } + + assert.Equal(t, v.Expected, innerIPs) + assert.Nil(t, n.Err()) + } +}