Skip to content

Commit

Permalink
all: remove redundant net.IP <-> string converions
Browse files Browse the repository at this point in the history
  • Loading branch information
EugeneOne1 committed Jan 18, 2021
1 parent 5df8af0 commit 7fd0634
Show file tree
Hide file tree
Showing 22 changed files with 167 additions and 157 deletions.
1 change: 1 addition & 0 deletions internal/dnsfilter/dnsfilter.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ type RequestFilteringSettings struct {
ParentalEnabled bool

ClientName string
// TODO(e.burkov): wait for urlfilter update to replace with net.IP.
ClientIP string
ClientTags []string

Expand Down
14 changes: 6 additions & 8 deletions internal/dnsforward/access.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,20 +83,19 @@ func processIPCIDRArray(dst *map[string]bool, dstIPNet *[]net.IPNet, src []strin
// Returns the item from the "disallowedClients" list that lead to blocking IP.
// If it returns TRUE and an empty string, it means that the "allowedClients" is not empty,
// but the ip does not belong to it.
func (a *accessCtx) IsBlockedIP(ip string) (bool, string) {
func (a *accessCtx) IsBlockedIP(ip net.IP) (bool, string) {
a.lock.Lock()
defer a.lock.Unlock()

if len(a.allowedClients) != 0 || len(a.allowedClientsIPNet) != 0 {
_, ok := a.allowedClients[ip]
_, ok := a.allowedClients[ip.String()]
if ok {
return false, ""
}

if len(a.allowedClientsIPNet) != 0 {
ipAddr := net.ParseIP(ip)
for _, ipnet := range a.allowedClientsIPNet {
if ipnet.Contains(ipAddr) {
if ipnet.Contains(ip) {
return false, ""
}
}
Expand All @@ -105,15 +104,14 @@ func (a *accessCtx) IsBlockedIP(ip string) (bool, string) {
return true, ""
}

_, ok := a.disallowedClients[ip]
_, ok := a.disallowedClients[ip.String()]
if ok {
return true, ip
return true, ip.String()
}

if len(a.disallowedClientsIPNet) != 0 {
ipAddr := net.ParseIP(ip)
for _, ipnet := range a.disallowedClientsIPNet {
if ipnet.Contains(ipAddr) {
if ipnet.Contains(ip) {
return true, ipnet.String()
}
}
Expand Down
17 changes: 9 additions & 8 deletions internal/dnsforward/access_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package dnsforward

import (
"net"
"testing"

"github.com/stretchr/testify/assert"
Expand All @@ -10,19 +11,19 @@ func TestIsBlockedIPAllowed(t *testing.T) {
a := &accessCtx{}
assert.Nil(t, a.Init([]string{"1.1.1.1", "2.2.0.0/16"}, nil, nil))

disallowed, disallowedRule := a.IsBlockedIP("1.1.1.1")
disallowed, disallowedRule := a.IsBlockedIP(net.IPv4(1, 1, 1, 1))
assert.False(t, disallowed)
assert.Empty(t, disallowedRule)

disallowed, disallowedRule = a.IsBlockedIP("1.1.1.2")
disallowed, disallowedRule = a.IsBlockedIP(net.IPv4(1, 1, 1, 2))
assert.True(t, disallowed)
assert.Empty(t, disallowedRule)

disallowed, disallowedRule = a.IsBlockedIP("2.2.1.1")
disallowed, disallowedRule = a.IsBlockedIP(net.IPv4(2, 2, 1, 1))
assert.False(t, disallowed)
assert.Empty(t, disallowedRule)

disallowed, disallowedRule = a.IsBlockedIP("2.3.1.1")
disallowed, disallowedRule = a.IsBlockedIP(net.IPv4(2, 3, 1, 1))
assert.True(t, disallowed)
assert.Empty(t, disallowedRule)
}
Expand All @@ -31,19 +32,19 @@ func TestIsBlockedIPDisallowed(t *testing.T) {
a := &accessCtx{}
assert.Nil(t, a.Init(nil, []string{"1.1.1.1", "2.2.0.0/16"}, nil))

disallowed, disallowedRule := a.IsBlockedIP("1.1.1.1")
disallowed, disallowedRule := a.IsBlockedIP(net.IPv4(1, 1, 1, 1))
assert.True(t, disallowed)
assert.Equal(t, "1.1.1.1", disallowedRule)

disallowed, disallowedRule = a.IsBlockedIP("1.1.1.2")
disallowed, disallowedRule = a.IsBlockedIP(net.IPv4(1, 1, 1, 2))
assert.False(t, disallowed)
assert.Empty(t, disallowedRule)

disallowed, disallowedRule = a.IsBlockedIP("2.2.1.1")
disallowed, disallowedRule = a.IsBlockedIP(net.IPv4(2, 2, 1, 1))
assert.True(t, disallowed)
assert.Equal(t, "2.2.0.0/16", disallowedRule)

disallowed, disallowedRule = a.IsBlockedIP("2.3.1.1")
disallowed, disallowedRule = a.IsBlockedIP(net.IPv4(2, 3, 1, 1))
assert.False(t, disallowed)
assert.Empty(t, disallowedRule)
}
Expand Down
2 changes: 1 addition & 1 deletion internal/dnsforward/dnsforward.go
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,6 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}

// IsBlockedIP - return TRUE if this client should be blocked
func (s *Server) IsBlockedIP(ip string) (bool, string) {
func (s *Server) IsBlockedIP(ip net.IP) (bool, string) {
return s.access.IsBlockedIP(ip)
}
2 changes: 1 addition & 1 deletion internal/dnsforward/filter.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (
)

func (s *Server) beforeRequestHandler(_ *proxy.Proxy, d *proxy.DNSContext) (bool, error) {
ip := IPStringFromAddr(d.Addr)
ip := IPFromAddr(d.Addr)
disallowed, _ := s.access.IsBlockedIP(ip)
if disallowed {
log.Tracef("Client IP %s is blocked by settings", ip)
Expand Down
2 changes: 1 addition & 1 deletion internal/dnsforward/stats.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func processQueryLogsAndStats(ctx *dnsContext) int {
OrigAnswer: ctx.origResp,
Result: ctx.result,
Elapsed: elapsed,
ClientIP: ipFromAddr(d.Addr),
ClientIP: IPFromAddr(d.Addr),
}

switch d.Proto {
Expand Down
6 changes: 3 additions & 3 deletions internal/dnsforward/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ import (
"github.com/AdguardTeam/golibs/utils"
)

// ipFromAddr gets IP address from addr.
func ipFromAddr(addr net.Addr) (ip net.IP) {
// IPFromAddr gets IP address from addr.
func IPFromAddr(addr net.Addr) (ip net.IP) {
switch addr := addr.(type) {
case *net.UDPAddr:
return addr.IP
Expand All @@ -23,7 +23,7 @@ func ipFromAddr(addr net.Addr) (ip net.IP) {
// Note: we can't use net.SplitHostPort(a.String()) because of IPv6 zone:
// https://github.com/AdguardTeam/AdGuardHome/internal/issues/1261
func IPStringFromAddr(addr net.Addr) (ipstr string) {
if ip := ipFromAddr(addr); ip != nil {
if ip := IPFromAddr(addr); ip != nil {
return ip.String()
}

Expand Down
41 changes: 20 additions & 21 deletions internal/home/clients.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,11 @@ type ClientHost struct {
}

type clientsContainer struct {
list map[string]*Client // name -> client
idIndex map[string]*Client // IP -> client
ipHost map[string]*ClientHost // IP -> Hostname
lock sync.Mutex
list map[string]*Client // name -> client
idIndex map[string]*Client // IP -> client
// TODO(e.burkov): think about using maphash.
ipHost map[string]*ClientHost // IP -> Hostname
lock sync.Mutex

allTags map[string]bool

Expand Down Expand Up @@ -239,7 +240,7 @@ func (clients *clientsContainer) onHostsChanged() {
}

// Exists checks if client with this IP already exists
func (clients *clientsContainer) Exists(ip string, source clientSource) bool {
func (clients *clientsContainer) Exists(ip net.IP, source clientSource) bool {
clients.lock.Lock()
defer clients.lock.Unlock()

Expand All @@ -248,7 +249,7 @@ func (clients *clientsContainer) Exists(ip string, source clientSource) bool {
return true
}

ch, ok := clients.ipHost[ip]
ch, ok := clients.ipHost[ip.String()]
if !ok {
return false
}
Expand All @@ -265,7 +266,7 @@ func stringArrayDup(a []string) []string {
}

// Find searches for a client by IP
func (clients *clientsContainer) Find(ip string) (Client, bool) {
func (clients *clientsContainer) Find(ip net.IP) (Client, bool) {
clients.lock.Lock()
defer clients.lock.Unlock()

Expand All @@ -287,7 +288,7 @@ func (clients *clientsContainer) FindUpstreams(ip string) *proxy.UpstreamConfig
clients.lock.Lock()
defer clients.lock.Unlock()

c, ok := clients.findByIP(ip)
c, ok := clients.findByIP(net.ParseIP(ip))
if !ok {
return nil
}
Expand All @@ -307,13 +308,12 @@ func (clients *clientsContainer) FindUpstreams(ip string) *proxy.UpstreamConfig
}

// Find searches for a client by IP (and does not lock anything)
func (clients *clientsContainer) findByIP(ip string) (Client, bool) {
ipAddr := net.ParseIP(ip)
if ipAddr == nil {
func (clients *clientsContainer) findByIP(ip net.IP) (Client, bool) {
if ip == nil {
return Client{}, false
}

c, ok := clients.idIndex[ip]
c, ok := clients.idIndex[ip.String()]
if ok {
return *c, true
}
Expand All @@ -324,7 +324,7 @@ func (clients *clientsContainer) findByIP(ip string) (Client, bool) {
if err != nil {
continue
}
if ipnet.Contains(ipAddr) {
if ipnet.Contains(ip) {
return *c, true
}
}
Expand All @@ -333,7 +333,7 @@ func (clients *clientsContainer) findByIP(ip string) (Client, bool) {
if clients.dhcpServer == nil {
return Client{}, false
}
macFound := clients.dhcpServer.FindMACbyIP(ipAddr)
macFound := clients.dhcpServer.FindMACbyIP(ip)
if macFound == nil {
return Client{}, false
}
Expand All @@ -353,16 +353,15 @@ func (clients *clientsContainer) findByIP(ip string) (Client, bool) {
}

// FindAutoClient - search for an auto-client by IP
func (clients *clientsContainer) FindAutoClient(ip string) (ClientHost, bool) {
ipAddr := net.ParseIP(ip)
if ipAddr == nil {
func (clients *clientsContainer) FindAutoClient(ip net.IP) (ClientHost, bool) {
if ip == nil {
return ClientHost{}, false
}

clients.lock.Lock()
defer clients.lock.Unlock()

ch, ok := clients.ipHost[ip]
ch, ok := clients.ipHost[ip.String()]
if ok {
return *ch, true
}
Expand Down Expand Up @@ -539,7 +538,7 @@ func (clients *clientsContainer) Update(name string, c Client) error {
}

// SetWhoisInfo - associate WHOIS information with a client
func (clients *clientsContainer) SetWhoisInfo(ip string, info [][]string) {
func (clients *clientsContainer) SetWhoisInfo(ip net.IP, info [][]string) {
clients.lock.Lock()
defer clients.lock.Unlock()

Expand All @@ -549,7 +548,7 @@ func (clients *clientsContainer) SetWhoisInfo(ip string, info [][]string) {
return
}

ch, ok := clients.ipHost[ip]
ch, ok := clients.ipHost[ip.String()]
if ok {
ch.WhoisInfo = info
log.Debug("Clients: set WHOIS info for auto-client %s: %v", ch.Host, ch.WhoisInfo)
Expand All @@ -561,7 +560,7 @@ func (clients *clientsContainer) SetWhoisInfo(ip string, info [][]string) {
Source: ClientSourceWHOIS,
}
ch.WhoisInfo = info
clients.ipHost[ip] = ch
clients.ipHost[ip.String()] = ch
log.Debug("Clients: set WHOIS info for auto-client with IP %s: %v", ip, ch.WhoisInfo)
}

Expand Down
28 changes: 14 additions & 14 deletions internal/home/clients_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,21 +36,21 @@ func TestClients(t *testing.T) {
assert.True(t, b)
assert.Nil(t, err)

c, b = clients.Find("1.1.1.1")
c, b = clients.Find(net.IPv4(1, 1, 1, 1))
assert.True(t, b)
assert.Equal(t, c.Name, "client1")

c, b = clients.Find("1:2:3::4")
c, b = clients.Find(net.ParseIP("1:2:3::4"))
assert.True(t, b)
assert.Equal(t, c.Name, "client1")

c, b = clients.Find("2.2.2.2")
c, b = clients.Find(net.IPv4(2, 2, 2, 2))
assert.True(t, b)
assert.Equal(t, c.Name, "client2")

assert.False(t, clients.Exists("1.2.3.4", ClientSourceHostsFile))
assert.True(t, clients.Exists("1.1.1.1", ClientSourceHostsFile))
assert.True(t, clients.Exists("2.2.2.2", ClientSourceHostsFile))
assert.False(t, clients.Exists(net.IPv4(1, 2, 3, 4), ClientSourceHostsFile))
assert.True(t, clients.Exists(net.IPv4(1, 1, 1, 1), ClientSourceHostsFile))
assert.True(t, clients.Exists(net.IPv4(2, 2, 2, 2), ClientSourceHostsFile))
})

t.Run("add_fail_name", func(t *testing.T) {
Expand Down Expand Up @@ -112,8 +112,8 @@ func TestClients(t *testing.T) {
err := clients.Update("client1", c)
assert.Nil(t, err)

assert.False(t, clients.Exists("1.1.1.1", ClientSourceHostsFile))
assert.True(t, clients.Exists("1.1.1.2", ClientSourceHostsFile))
assert.False(t, clients.Exists(net.IPv4(1, 1, 1, 1), ClientSourceHostsFile))
assert.True(t, clients.Exists(net.IPv4(1, 1, 1, 2), ClientSourceHostsFile))

c = Client{
IDs: []string{"1.1.1.2"},
Expand All @@ -124,7 +124,7 @@ func TestClients(t *testing.T) {
err = clients.Update("client1", c)
assert.Nil(t, err)

c, b := clients.Find("1.1.1.2")
c, b := clients.Find(net.IPv4(1, 1, 1, 2))
assert.True(t, b)
assert.Equal(t, "client1-renamed", c.Name)
assert.Equal(t, "1.1.1.2", c.IDs[0])
Expand All @@ -135,7 +135,7 @@ func TestClients(t *testing.T) {
t.Run("del_success", func(t *testing.T) {
b := clients.Del("client1-renamed")
assert.True(t, b)
assert.False(t, clients.Exists("1.1.1.2", ClientSourceHostsFile))
assert.False(t, clients.Exists(net.IPv4(1, 1, 1, 2), ClientSourceHostsFile))
})

t.Run("del_fail", func(t *testing.T) {
Expand All @@ -156,7 +156,7 @@ func TestClients(t *testing.T) {
assert.True(t, b)
assert.Nil(t, err)

assert.True(t, clients.Exists("1.1.1.1", ClientSourceHostsFile))
assert.True(t, clients.Exists(net.IPv4(1, 1, 1, 1), ClientSourceHostsFile))
})

t.Run("addhost_fail", func(t *testing.T) {
Expand All @@ -174,12 +174,12 @@ func TestClientsWhois(t *testing.T) {

whois := [][]string{{"orgname", "orgname-val"}, {"country", "country-val"}}
// set whois info on new client
clients.SetWhoisInfo("1.1.1.255", whois)
clients.SetWhoisInfo(net.IPv4(1, 1, 1, 255), whois)
assert.Equal(t, "orgname-val", clients.ipHost["1.1.1.255"].WhoisInfo[0][1])

// set whois info on existing auto-client
_, _ = clients.AddHost("1.1.1.1", "host", ClientSourceRDNS)
clients.SetWhoisInfo("1.1.1.1", whois)
clients.SetWhoisInfo(net.IPv4(1, 1, 1, 1), whois)
assert.Equal(t, "orgname-val", clients.ipHost["1.1.1.1"].WhoisInfo[0][1])

// Check that we cannot set whois info on a manually-added client
Expand All @@ -188,7 +188,7 @@ func TestClientsWhois(t *testing.T) {
Name: "client1",
}
_, _ = clients.Add(c)
clients.SetWhoisInfo("1.1.1.2", whois)
clients.SetWhoisInfo(net.IPv4(1, 1, 1, 2), whois)
assert.Nil(t, clients.ipHost["1.1.1.2"])
_ = clients.Del("client1")
}
Expand Down
Loading

0 comments on commit 7fd0634

Please sign in to comment.