diff --git a/internal/dnsfilter/dnsfilter.go b/internal/dnsfilter/dnsfilter.go index c5c28aff722..6d8e2d26b34 100644 --- a/internal/dnsfilter/dnsfilter.go +++ b/internal/dnsfilter/dnsfilter.go @@ -36,6 +36,7 @@ type RequestFilteringSettings struct { ParentalEnabled bool ClientName string + // TODO(e.burkov): wait for urlfilter update to replace with net.IP. ClientIP string ClientTags []string diff --git a/internal/dnsforward/access.go b/internal/dnsforward/access.go index 5038a89ac00..20c785e5070 100644 --- a/internal/dnsforward/access.go +++ b/internal/dnsforward/access.go @@ -83,20 +83,19 @@ func processIPCIDRArray(dst *map[string]bool, dstIPNet *[]net.IPNet, src []strin // Returns the item from the "disallowedClients" list that lead to blocking IP. // If it returns TRUE and an empty string, it means that the "allowedClients" is not empty, // but the ip does not belong to it. -func (a *accessCtx) IsBlockedIP(ip string) (bool, string) { +func (a *accessCtx) IsBlockedIP(ip net.IP) (bool, string) { a.lock.Lock() defer a.lock.Unlock() if len(a.allowedClients) != 0 || len(a.allowedClientsIPNet) != 0 { - _, ok := a.allowedClients[ip] + _, ok := a.allowedClients[ip.String()] if ok { return false, "" } if len(a.allowedClientsIPNet) != 0 { - ipAddr := net.ParseIP(ip) for _, ipnet := range a.allowedClientsIPNet { - if ipnet.Contains(ipAddr) { + if ipnet.Contains(ip) { return false, "" } } @@ -105,15 +104,14 @@ func (a *accessCtx) IsBlockedIP(ip string) (bool, string) { return true, "" } - _, ok := a.disallowedClients[ip] + _, ok := a.disallowedClients[ip.String()] if ok { - return true, ip + return true, ip.String() } if len(a.disallowedClientsIPNet) != 0 { - ipAddr := net.ParseIP(ip) for _, ipnet := range a.disallowedClientsIPNet { - if ipnet.Contains(ipAddr) { + if ipnet.Contains(ip) { return true, ipnet.String() } } diff --git a/internal/dnsforward/access_test.go b/internal/dnsforward/access_test.go index 5c225b21e98..af13b02e459 100644 --- a/internal/dnsforward/access_test.go +++ b/internal/dnsforward/access_test.go @@ -1,6 +1,7 @@ package dnsforward import ( + "net" "testing" "github.com/stretchr/testify/assert" @@ -10,19 +11,19 @@ func TestIsBlockedIPAllowed(t *testing.T) { a := &accessCtx{} assert.Nil(t, a.Init([]string{"1.1.1.1", "2.2.0.0/16"}, nil, nil)) - disallowed, disallowedRule := a.IsBlockedIP("1.1.1.1") + disallowed, disallowedRule := a.IsBlockedIP(net.IPv4(1, 1, 1, 1)) assert.False(t, disallowed) assert.Empty(t, disallowedRule) - disallowed, disallowedRule = a.IsBlockedIP("1.1.1.2") + disallowed, disallowedRule = a.IsBlockedIP(net.IPv4(1, 1, 1, 2)) assert.True(t, disallowed) assert.Empty(t, disallowedRule) - disallowed, disallowedRule = a.IsBlockedIP("2.2.1.1") + disallowed, disallowedRule = a.IsBlockedIP(net.IPv4(2, 2, 1, 1)) assert.False(t, disallowed) assert.Empty(t, disallowedRule) - disallowed, disallowedRule = a.IsBlockedIP("2.3.1.1") + disallowed, disallowedRule = a.IsBlockedIP(net.IPv4(2, 3, 1, 1)) assert.True(t, disallowed) assert.Empty(t, disallowedRule) } @@ -31,19 +32,19 @@ func TestIsBlockedIPDisallowed(t *testing.T) { a := &accessCtx{} assert.Nil(t, a.Init(nil, []string{"1.1.1.1", "2.2.0.0/16"}, nil)) - disallowed, disallowedRule := a.IsBlockedIP("1.1.1.1") + disallowed, disallowedRule := a.IsBlockedIP(net.IPv4(1, 1, 1, 1)) assert.True(t, disallowed) assert.Equal(t, "1.1.1.1", disallowedRule) - disallowed, disallowedRule = a.IsBlockedIP("1.1.1.2") + disallowed, disallowedRule = a.IsBlockedIP(net.IPv4(1, 1, 1, 2)) assert.False(t, disallowed) assert.Empty(t, disallowedRule) - disallowed, disallowedRule = a.IsBlockedIP("2.2.1.1") + disallowed, disallowedRule = a.IsBlockedIP(net.IPv4(2, 2, 1, 1)) assert.True(t, disallowed) assert.Equal(t, "2.2.0.0/16", disallowedRule) - disallowed, disallowedRule = a.IsBlockedIP("2.3.1.1") + disallowed, disallowedRule = a.IsBlockedIP(net.IPv4(2, 3, 1, 1)) assert.False(t, disallowed) assert.Empty(t, disallowedRule) } diff --git a/internal/dnsforward/dnsforward.go b/internal/dnsforward/dnsforward.go index ab6bea275e7..0ada0640087 100644 --- a/internal/dnsforward/dnsforward.go +++ b/internal/dnsforward/dnsforward.go @@ -298,6 +298,6 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { } // IsBlockedIP - return TRUE if this client should be blocked -func (s *Server) IsBlockedIP(ip string) (bool, string) { +func (s *Server) IsBlockedIP(ip net.IP) (bool, string) { return s.access.IsBlockedIP(ip) } diff --git a/internal/dnsforward/filter.go b/internal/dnsforward/filter.go index 77ae30a94d5..b4797dd0683 100644 --- a/internal/dnsforward/filter.go +++ b/internal/dnsforward/filter.go @@ -12,7 +12,7 @@ import ( ) func (s *Server) beforeRequestHandler(_ *proxy.Proxy, d *proxy.DNSContext) (bool, error) { - ip := IPStringFromAddr(d.Addr) + ip := IPFromAddr(d.Addr) disallowed, _ := s.access.IsBlockedIP(ip) if disallowed { log.Tracef("Client IP %s is blocked by settings", ip) diff --git a/internal/dnsforward/stats.go b/internal/dnsforward/stats.go index 822df6a0a7d..be45b0f98c0 100644 --- a/internal/dnsforward/stats.go +++ b/internal/dnsforward/stats.go @@ -36,7 +36,7 @@ func processQueryLogsAndStats(ctx *dnsContext) int { OrigAnswer: ctx.origResp, Result: ctx.result, Elapsed: elapsed, - ClientIP: ipFromAddr(d.Addr), + ClientIP: IPFromAddr(d.Addr), } switch d.Proto { diff --git a/internal/dnsforward/util.go b/internal/dnsforward/util.go index 3a8c1cb3027..7e9b55f0c43 100644 --- a/internal/dnsforward/util.go +++ b/internal/dnsforward/util.go @@ -8,8 +8,8 @@ import ( "github.com/AdguardTeam/golibs/utils" ) -// ipFromAddr gets IP address from addr. -func ipFromAddr(addr net.Addr) (ip net.IP) { +// IPFromAddr gets IP address from addr. +func IPFromAddr(addr net.Addr) (ip net.IP) { switch addr := addr.(type) { case *net.UDPAddr: return addr.IP @@ -23,7 +23,7 @@ func ipFromAddr(addr net.Addr) (ip net.IP) { // Note: we can't use net.SplitHostPort(a.String()) because of IPv6 zone: // https://github.com/AdguardTeam/AdGuardHome/internal/issues/1261 func IPStringFromAddr(addr net.Addr) (ipstr string) { - if ip := ipFromAddr(addr); ip != nil { + if ip := IPFromAddr(addr); ip != nil { return ip.String() } diff --git a/internal/home/clients.go b/internal/home/clients.go index 3c6bfa48c1e..d7984d98c70 100644 --- a/internal/home/clients.go +++ b/internal/home/clients.go @@ -70,10 +70,11 @@ type ClientHost struct { } type clientsContainer struct { - list map[string]*Client // name -> client - idIndex map[string]*Client // IP -> client - ipHost map[string]*ClientHost // IP -> Hostname - lock sync.Mutex + list map[string]*Client // name -> client + idIndex map[string]*Client // IP -> client + // TODO(e.burkov): think about using maphash. + ipHost map[string]*ClientHost // IP -> Hostname + lock sync.Mutex allTags map[string]bool @@ -239,7 +240,7 @@ func (clients *clientsContainer) onHostsChanged() { } // Exists checks if client with this IP already exists -func (clients *clientsContainer) Exists(ip string, source clientSource) bool { +func (clients *clientsContainer) Exists(ip net.IP, source clientSource) bool { clients.lock.Lock() defer clients.lock.Unlock() @@ -248,7 +249,7 @@ func (clients *clientsContainer) Exists(ip string, source clientSource) bool { return true } - ch, ok := clients.ipHost[ip] + ch, ok := clients.ipHost[ip.String()] if !ok { return false } @@ -265,7 +266,7 @@ func stringArrayDup(a []string) []string { } // Find searches for a client by IP -func (clients *clientsContainer) Find(ip string) (Client, bool) { +func (clients *clientsContainer) Find(ip net.IP) (Client, bool) { clients.lock.Lock() defer clients.lock.Unlock() @@ -287,7 +288,7 @@ func (clients *clientsContainer) FindUpstreams(ip string) *proxy.UpstreamConfig clients.lock.Lock() defer clients.lock.Unlock() - c, ok := clients.findByIP(ip) + c, ok := clients.findByIP(net.ParseIP(ip)) if !ok { return nil } @@ -307,13 +308,12 @@ func (clients *clientsContainer) FindUpstreams(ip string) *proxy.UpstreamConfig } // Find searches for a client by IP (and does not lock anything) -func (clients *clientsContainer) findByIP(ip string) (Client, bool) { - ipAddr := net.ParseIP(ip) - if ipAddr == nil { +func (clients *clientsContainer) findByIP(ip net.IP) (Client, bool) { + if ip == nil { return Client{}, false } - c, ok := clients.idIndex[ip] + c, ok := clients.idIndex[ip.String()] if ok { return *c, true } @@ -324,7 +324,7 @@ func (clients *clientsContainer) findByIP(ip string) (Client, bool) { if err != nil { continue } - if ipnet.Contains(ipAddr) { + if ipnet.Contains(ip) { return *c, true } } @@ -333,7 +333,7 @@ func (clients *clientsContainer) findByIP(ip string) (Client, bool) { if clients.dhcpServer == nil { return Client{}, false } - macFound := clients.dhcpServer.FindMACbyIP(ipAddr) + macFound := clients.dhcpServer.FindMACbyIP(ip) if macFound == nil { return Client{}, false } @@ -353,16 +353,15 @@ func (clients *clientsContainer) findByIP(ip string) (Client, bool) { } // FindAutoClient - search for an auto-client by IP -func (clients *clientsContainer) FindAutoClient(ip string) (ClientHost, bool) { - ipAddr := net.ParseIP(ip) - if ipAddr == nil { +func (clients *clientsContainer) FindAutoClient(ip net.IP) (ClientHost, bool) { + if ip == nil { return ClientHost{}, false } clients.lock.Lock() defer clients.lock.Unlock() - ch, ok := clients.ipHost[ip] + ch, ok := clients.ipHost[ip.String()] if ok { return *ch, true } @@ -539,7 +538,7 @@ func (clients *clientsContainer) Update(name string, c Client) error { } // SetWhoisInfo - associate WHOIS information with a client -func (clients *clientsContainer) SetWhoisInfo(ip string, info [][]string) { +func (clients *clientsContainer) SetWhoisInfo(ip net.IP, info [][]string) { clients.lock.Lock() defer clients.lock.Unlock() @@ -549,7 +548,7 @@ func (clients *clientsContainer) SetWhoisInfo(ip string, info [][]string) { return } - ch, ok := clients.ipHost[ip] + ch, ok := clients.ipHost[ip.String()] if ok { ch.WhoisInfo = info log.Debug("Clients: set WHOIS info for auto-client %s: %v", ch.Host, ch.WhoisInfo) @@ -561,7 +560,7 @@ func (clients *clientsContainer) SetWhoisInfo(ip string, info [][]string) { Source: ClientSourceWHOIS, } ch.WhoisInfo = info - clients.ipHost[ip] = ch + clients.ipHost[ip.String()] = ch log.Debug("Clients: set WHOIS info for auto-client with IP %s: %v", ip, ch.WhoisInfo) } diff --git a/internal/home/clients_test.go b/internal/home/clients_test.go index 69f2badabcd..94ff8009506 100644 --- a/internal/home/clients_test.go +++ b/internal/home/clients_test.go @@ -36,21 +36,21 @@ func TestClients(t *testing.T) { assert.True(t, b) assert.Nil(t, err) - c, b = clients.Find("1.1.1.1") + c, b = clients.Find(net.IPv4(1, 1, 1, 1)) assert.True(t, b) assert.Equal(t, c.Name, "client1") - c, b = clients.Find("1:2:3::4") + c, b = clients.Find(net.ParseIP("1:2:3::4")) assert.True(t, b) assert.Equal(t, c.Name, "client1") - c, b = clients.Find("2.2.2.2") + c, b = clients.Find(net.IPv4(2, 2, 2, 2)) assert.True(t, b) assert.Equal(t, c.Name, "client2") - assert.False(t, clients.Exists("1.2.3.4", ClientSourceHostsFile)) - assert.True(t, clients.Exists("1.1.1.1", ClientSourceHostsFile)) - assert.True(t, clients.Exists("2.2.2.2", ClientSourceHostsFile)) + assert.False(t, clients.Exists(net.IPv4(1, 2, 3, 4), ClientSourceHostsFile)) + assert.True(t, clients.Exists(net.IPv4(1, 1, 1, 1), ClientSourceHostsFile)) + assert.True(t, clients.Exists(net.IPv4(2, 2, 2, 2), ClientSourceHostsFile)) }) t.Run("add_fail_name", func(t *testing.T) { @@ -112,8 +112,8 @@ func TestClients(t *testing.T) { err := clients.Update("client1", c) assert.Nil(t, err) - assert.False(t, clients.Exists("1.1.1.1", ClientSourceHostsFile)) - assert.True(t, clients.Exists("1.1.1.2", ClientSourceHostsFile)) + assert.False(t, clients.Exists(net.IPv4(1, 1, 1, 1), ClientSourceHostsFile)) + assert.True(t, clients.Exists(net.IPv4(1, 1, 1, 2), ClientSourceHostsFile)) c = Client{ IDs: []string{"1.1.1.2"}, @@ -124,7 +124,7 @@ func TestClients(t *testing.T) { err = clients.Update("client1", c) assert.Nil(t, err) - c, b := clients.Find("1.1.1.2") + c, b := clients.Find(net.IPv4(1, 1, 1, 2)) assert.True(t, b) assert.Equal(t, "client1-renamed", c.Name) assert.Equal(t, "1.1.1.2", c.IDs[0]) @@ -135,7 +135,7 @@ func TestClients(t *testing.T) { t.Run("del_success", func(t *testing.T) { b := clients.Del("client1-renamed") assert.True(t, b) - assert.False(t, clients.Exists("1.1.1.2", ClientSourceHostsFile)) + assert.False(t, clients.Exists(net.IPv4(1, 1, 1, 2), ClientSourceHostsFile)) }) t.Run("del_fail", func(t *testing.T) { @@ -156,7 +156,7 @@ func TestClients(t *testing.T) { assert.True(t, b) assert.Nil(t, err) - assert.True(t, clients.Exists("1.1.1.1", ClientSourceHostsFile)) + assert.True(t, clients.Exists(net.IPv4(1, 1, 1, 1), ClientSourceHostsFile)) }) t.Run("addhost_fail", func(t *testing.T) { @@ -174,12 +174,12 @@ func TestClientsWhois(t *testing.T) { whois := [][]string{{"orgname", "orgname-val"}, {"country", "country-val"}} // set whois info on new client - clients.SetWhoisInfo("1.1.1.255", whois) + clients.SetWhoisInfo(net.IPv4(1, 1, 1, 255), whois) assert.Equal(t, "orgname-val", clients.ipHost["1.1.1.255"].WhoisInfo[0][1]) // set whois info on existing auto-client _, _ = clients.AddHost("1.1.1.1", "host", ClientSourceRDNS) - clients.SetWhoisInfo("1.1.1.1", whois) + clients.SetWhoisInfo(net.IPv4(1, 1, 1, 1), whois) assert.Equal(t, "orgname-val", clients.ipHost["1.1.1.1"].WhoisInfo[0][1]) // Check that we cannot set whois info on a manually-added client @@ -188,7 +188,7 @@ func TestClientsWhois(t *testing.T) { Name: "client1", } _, _ = clients.Add(c) - clients.SetWhoisInfo("1.1.1.2", whois) + clients.SetWhoisInfo(net.IPv4(1, 1, 1, 2), whois) assert.Nil(t, clients.ipHost["1.1.1.2"]) _ = clients.Del("client1") } diff --git a/internal/home/clientshttp.go b/internal/home/clientshttp.go index d8cc3ee3c6e..a7f1925f99e 100644 --- a/internal/home/clientshttp.go +++ b/internal/home/clientshttp.go @@ -3,6 +3,7 @@ package home import ( "encoding/json" "fmt" + "net" "net/http" ) @@ -229,8 +230,8 @@ func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http q := r.URL.Query() data := []map[string]interface{}{} for i := 0; ; i++ { - ip := q.Get(fmt.Sprintf("ip%d", i)) - if len(ip) == 0 { + ip := net.ParseIP(q.Get(fmt.Sprintf("ip%d", i))) + if ip == nil { break } el := map[string]interface{}{} @@ -240,15 +241,15 @@ func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http if !ok { continue // a client with this IP isn't found } - cj := clientHostToJSON(ip, ch) + cj := clientHostToJSON(ip.String(), ch) cj.Disallowed, cj.DisallowedRule = clients.dnsServer.IsBlockedIP(ip) - el[ip] = cj + el[ip.String()] = cj } else { cj := clientToJSON(&c) cj.Disallowed, cj.DisallowedRule = clients.dnsServer.IsBlockedIP(ip) - el[ip] = cj + el[ip.String()] = cj } data = append(data, el) diff --git a/internal/home/config.go b/internal/home/config.go index f7b799dc507..6c8381ab63d 100644 --- a/internal/home/config.go +++ b/internal/home/config.go @@ -2,6 +2,7 @@ package home import ( "io/ioutil" + "net" "os" "path/filepath" "sync" @@ -40,7 +41,7 @@ type configuration struct { // It's reset after config is parsed fileData []byte - BindHost string `yaml:"bind_host"` // BindHost is the IP address of the HTTP server to bind to + BindHost net.IP `yaml:"bind_host"` // BindHost is the IP address of the HTTP server to bind to BindPort int `yaml:"bind_port"` // BindPort is the port the HTTP server BetaBindPort int `yaml:"beta_bind_port"` // BetaBindPort is the port for new client Users []User `yaml:"users"` // Users that can access HTTP server @@ -74,7 +75,7 @@ type configuration struct { // field ordering is important -- yaml fields will mirror ordering from here type dnsConfig struct { - BindHost string `yaml:"bind_host"` + BindHost net.IP `yaml:"bind_host"` Port int `yaml:"port"` // time interval for statistics (in days) @@ -121,9 +122,9 @@ type tlsConfigSettings struct { var config = configuration{ BindPort: 3000, BetaBindPort: 0, - BindHost: "0.0.0.0", + BindHost: net.IP{0, 0, 0, 0}, DNS: dnsConfig{ - BindHost: "0.0.0.0", + BindHost: net.IP{0, 0, 0, 0}, Port: 53, StatsInterval: 1, FilteringConfig: dnsforward.FilteringConfig{ diff --git a/internal/home/controlinstall.go b/internal/home/controlinstall.go index 7d67d140fd6..f76d397bfb3 100644 --- a/internal/home/controlinstall.go +++ b/internal/home/controlinstall.go @@ -31,7 +31,7 @@ type netInterfaceJSON struct { Name string `json:"name"` MTU int `json:"mtu"` HardwareAddr string `json:"hardware_address"` - Addresses []string `json:"ip_addresses"` + Addresses []net.IP `json:"ip_addresses"` Flags string `json:"flags"` } @@ -69,7 +69,7 @@ func (web *Web) handleInstallGetAddresses(w http.ResponseWriter, r *http.Request type checkConfigReqEnt struct { Port int `json:"port"` - IP string `json:"ip"` + IP net.IP `json:"ip"` Autofix bool `json:"autofix"` } @@ -85,9 +85,9 @@ type checkConfigRespEnt struct { } type staticIPJSON struct { - Static string `json:"static"` - IP string `json:"ip"` - Error string `json:"error"` + Static string `json:"static"` + IP net.IPNet `json:"ip"` + Error string `json:"error"` } type checkConfigResp struct { @@ -107,14 +107,14 @@ func (web *Web) handleInstallCheckConfig(w http.ResponseWriter, r *http.Request) } if reqData.Web.Port != 0 && reqData.Web.Port != config.BindPort && reqData.Web.Port != config.BetaBindPort { - err = util.CheckPortAvailable(reqData.Web.IP, reqData.Web.Port) + err = util.CheckPortAvailable(reqData.Web.IP.String(), reqData.Web.Port) if err != nil { respData.Web.Status = fmt.Sprintf("%v", err) } } if reqData.DNS.Port != 0 { - err = util.CheckPacketPortAvailable(reqData.DNS.IP, reqData.DNS.Port) + err = util.CheckPacketPortAvailable(reqData.DNS.IP.String(), reqData.DNS.Port) if util.ErrorIsAddrInUse(err) { canAutofix := checkDNSStubListener() @@ -125,7 +125,7 @@ func (web *Web) handleInstallCheckConfig(w http.ResponseWriter, r *http.Request) log.Error("Couldn't disable DNSStubListener: %s", err) } - err = util.CheckPacketPortAvailable(reqData.DNS.IP, reqData.DNS.Port) + err = util.CheckPacketPortAvailable(reqData.DNS.IP.String(), reqData.DNS.Port) canAutofix = false } @@ -133,12 +133,12 @@ func (web *Web) handleInstallCheckConfig(w http.ResponseWriter, r *http.Request) } if err == nil { - err = util.CheckPortAvailable(reqData.DNS.IP, reqData.DNS.Port) + err = util.CheckPortAvailable(reqData.DNS.IP.String(), reqData.DNS.Port) } if err != nil { respData.DNS.Status = fmt.Sprintf("%v", err) - } else if reqData.DNS.IP != "0.0.0.0" { + } else if !reqData.DNS.IP.Equal(net.IP{0, 0, 0, 0}) { respData.StaticIP = handleStaticIP(reqData.DNS.IP, reqData.SetStaticIP) } } @@ -154,7 +154,7 @@ func (web *Web) handleInstallCheckConfig(w http.ResponseWriter, r *http.Request) // handleStaticIP - handles static IP request // It either checks if we have a static IP // Or if set=true, it tries to set it -func handleStaticIP(ip string, set bool) staticIPJSON { +func handleStaticIP(ip net.IP, set bool) staticIPJSON { resp := staticIPJSON{} interfaceName := util.GetInterfaceByIP(ip) @@ -262,7 +262,7 @@ func disableDNSStubListener() error { } type applyConfigReqEnt struct { - IP string `json:"ip"` + IP net.IP `json:"ip"` Port int `json:"port"` } @@ -297,29 +297,29 @@ func (web *Web) handleInstallConfigure(w http.ResponseWriter, r *http.Request) { } restartHTTP := true - if config.BindHost == newSettings.Web.IP && config.BindPort == newSettings.Web.Port { + if config.BindHost.Equal(newSettings.Web.IP) && config.BindPort == newSettings.Web.Port { // no need to rebind restartHTTP = false } // validate that hosts and ports are bindable if restartHTTP { - err = util.CheckPortAvailable(newSettings.Web.IP, newSettings.Web.Port) + err = util.CheckPortAvailable(newSettings.Web.IP.String(), newSettings.Web.Port) if err != nil { httpError(w, http.StatusBadRequest, "Impossible to listen on IP:port %s due to %s", - net.JoinHostPort(newSettings.Web.IP, strconv.Itoa(newSettings.Web.Port)), err) + net.JoinHostPort(newSettings.Web.IP.String(), strconv.Itoa(newSettings.Web.Port)), err) return } } - err = util.CheckPacketPortAvailable(newSettings.DNS.IP, newSettings.DNS.Port) + err = util.CheckPacketPortAvailable(newSettings.DNS.IP.String(), newSettings.DNS.Port) if err != nil { httpError(w, http.StatusBadRequest, "%s", err) return } - err = util.CheckPortAvailable(newSettings.DNS.IP, newSettings.DNS.Port) + err = util.CheckPortAvailable(newSettings.DNS.IP.String(), newSettings.DNS.Port) if err != nil { httpError(w, http.StatusBadRequest, "%s", err) return @@ -392,7 +392,7 @@ func (web *Web) registerInstallHandlers() { // functionality will appear in default checkConfigReqEnt. type checkConfigReqEntBeta struct { Port int `json:"port"` - IP []string `json:"ip"` + IP []net.IP `json:"ip"` Autofix bool `json:"autofix"` } @@ -459,7 +459,7 @@ func (web *Web) handleInstallCheckConfigBeta(w http.ResponseWriter, r *http.Requ // TODO(e.burkov): this should removed with the API v1 when the appropriate // functionality will appear in default applyConfigReqEnt. type applyConfigReqEntBeta struct { - IP []string `json:"ip"` + IP []net.IP `json:"ip"` Port int `json:"port"` } diff --git a/internal/home/dns.go b/internal/home/dns.go index a988062954f..9560599cda9 100644 --- a/internal/home/dns.go +++ b/internal/home/dns.go @@ -55,8 +55,8 @@ func initDNSServer() error { filterConf := config.DNS.DnsfilterConf bindhost := config.DNS.BindHost - if config.DNS.BindHost == "0.0.0.0" { - bindhost = "127.0.0.1" + if config.DNS.BindHost.Equal(net.IP{0, 0, 0, 0}) { + bindhost = net.IPv4(127, 0, 0, 1) } filterConf.ResolverAddress = fmt.Sprintf("%s:%d", bindhost, config.DNS.Port) filterConf.AutoHosts = &Context.autoHosts @@ -98,26 +98,24 @@ func isRunning() bool { } func onDNSRequest(d *proxy.DNSContext) { - ip := dnsforward.IPStringFromAddr(d.Addr) - if ip == "" { + ip := dnsforward.IPFromAddr(d.Addr) + if ip == nil { // This would be quite weird if we get here return } - ipAddr := net.ParseIP(ip) - if !ipAddr.IsLoopback() { + if !ip.IsLoopback() { Context.rdns.Begin(ip) } - if !Context.ipDetector.detectSpecialNetwork(ipAddr) { + if !Context.ipDetector.detectSpecialNetwork(ip) { Context.whois.Begin(ip) } } func generateServerConfig() (newconfig dnsforward.ServerConfig, err error) { - bindHost := net.ParseIP(config.DNS.BindHost) newconfig = dnsforward.ServerConfig{ - UDPListenAddr: &net.UDPAddr{IP: bindHost, Port: config.DNS.Port}, - TCPListenAddr: &net.TCPAddr{IP: bindHost, Port: config.DNS.Port}, + UDPListenAddr: &net.UDPAddr{IP: config.DNS.BindHost, Port: config.DNS.Port}, + TCPListenAddr: &net.TCPAddr{IP: config.DNS.BindHost, Port: config.DNS.Port}, FilteringConfig: config.DNS.FilteringConfig, ConfigModified: onConfigModified, HTTPRegister: httpRegister, @@ -131,20 +129,20 @@ func generateServerConfig() (newconfig dnsforward.ServerConfig, err error) { if tlsConf.PortDNSOverTLS != 0 { newconfig.TLSListenAddr = &net.TCPAddr{ - IP: bindHost, + IP: config.DNS.BindHost, Port: tlsConf.PortDNSOverTLS, } } if tlsConf.PortDNSOverQUIC != 0 { newconfig.QUICListenAddr = &net.UDPAddr{ - IP: bindHost, + IP: config.DNS.BindHost, Port: int(tlsConf.PortDNSOverQUIC), } } if tlsConf.PortDNSCrypt != 0 { - newconfig.DNSCryptConfig, err = newDNSCrypt(bindHost, tlsConf) + newconfig.DNSCryptConfig, err = newDNSCrypt(config.DNS.BindHost, tlsConf) if err != nil { // Don't wrap the error, because it's already // wrapped by newDNSCrypt. @@ -245,7 +243,7 @@ func getDNSEncryption() dnsEncryption { func getDNSAddresses() []string { dnsAddresses := []string{} - if config.DNS.BindHost == "0.0.0.0" { + if config.DNS.BindHost.Equal(net.IP{0, 0, 0, 0}) { ifaces, e := util.GetValidNetInterfacesForWeb() if e != nil { log.Error("Couldn't get network interfaces: %v", e) @@ -254,11 +252,11 @@ func getDNSAddresses() []string { for _, iface := range ifaces { for _, addr := range iface.Addresses { - addDNSAddress(&dnsAddresses, addr) + addDNSAddress(&dnsAddresses, addr.String()) } } } else { - addDNSAddress(&dnsAddresses, config.DNS.BindHost) + addDNSAddress(&dnsAddresses, config.DNS.BindHost.String()) } dnsEncryption := getDNSEncryption() @@ -284,7 +282,7 @@ func applyAdditionalFiltering(clientAddr string, setts *dnsfilter.RequestFilteri } setts.ClientIP = clientAddr - c, ok := Context.clients.Find(clientAddr) + c, ok := Context.clients.Find(net.ParseIP(clientAddr)) if !ok { return } @@ -332,10 +330,10 @@ func startDNSServer() error { for _, ip := range topClients { ipAddr := net.ParseIP(ip) if !ipAddr.IsLoopback() { - Context.rdns.Begin(ip) + Context.rdns.Begin(ipAddr) } if !Context.ipDetector.detectSpecialNetwork(ipAddr) { - Context.whois.Begin(ip) + Context.whois.Begin(ipAddr) } } diff --git a/internal/home/home.go b/internal/home/home.go index fbead57a124..8665894f698 100644 --- a/internal/home/home.go +++ b/internal/home/home.go @@ -206,7 +206,7 @@ func setupConfig(args options) { } // override bind host/port from the console - if args.bindHost != "" { + if args.bindHost != nil { config.BindHost = args.bindHost } if args.bindPort != 0 { @@ -581,30 +581,30 @@ func printHTTPAddresses(proto string) { } else { log.Printf("Go to https://%s:%s", tlsConf.ServerName, port) } - } else if config.BindHost == "0.0.0.0" { + } else if config.BindHost.Equal(net.IP{0, 0, 0, 0}) { log.Println("AdGuard Home is available on the following addresses:") ifaces, err := util.GetValidNetInterfacesForWeb() if err != nil { // That's weird, but we'll ignore it - log.Printf("Go to %s://%s", proto, net.JoinHostPort(config.BindHost, port)) + log.Printf("Go to %s://%s", proto, net.JoinHostPort(config.BindHost.String(), port)) if config.BetaBindPort != 0 { - log.Printf("Go to %s://%s (BETA)", proto, net.JoinHostPort(config.BindHost, strconv.Itoa(config.BetaBindPort))) + log.Printf("Go to %s://%s (BETA)", proto, net.JoinHostPort(config.BindHost.String(), strconv.Itoa(config.BetaBindPort))) } return } for _, iface := range ifaces { for _, addr := range iface.Addresses { - log.Printf("Go to %s://%s", proto, net.JoinHostPort(addr, strconv.Itoa(config.BindPort))) + log.Printf("Go to %s://%s", proto, net.JoinHostPort(addr.String(), strconv.Itoa(config.BindPort))) if config.BetaBindPort != 0 { - log.Printf("Go to %s://%s (BETA)", proto, net.JoinHostPort(addr, strconv.Itoa(config.BetaBindPort))) + log.Printf("Go to %s://%s (BETA)", proto, net.JoinHostPort(addr.String(), strconv.Itoa(config.BetaBindPort))) } } } } else { - log.Printf("Go to %s://%s", proto, net.JoinHostPort(config.BindHost, port)) + log.Printf("Go to %s://%s", proto, net.JoinHostPort(config.BindHost.String(), port)) if config.BetaBindPort != 0 { - log.Printf("Go to %s://%s (BETA)", proto, net.JoinHostPort(config.BindHost, strconv.Itoa(config.BetaBindPort))) + log.Printf("Go to %s://%s (BETA)", proto, net.JoinHostPort(config.BindHost.String(), strconv.Itoa(config.BetaBindPort))) } } } diff --git a/internal/home/options.go b/internal/home/options.go index 0493e85684b..897cbd082c9 100644 --- a/internal/home/options.go +++ b/internal/home/options.go @@ -2,6 +2,7 @@ package home import ( "fmt" + "net" "os" "strconv" @@ -13,7 +14,7 @@ type options struct { verbose bool // is verbose logging enabled configFilename string // path to the config file workDir string // path to the working directory where we will store the filters data and the querylog - bindHost string // host address to bind HTTP server on + bindHost net.IP // host address to bind HTTP server on bindPort int // port to serve HTTP pages on logFile string // Path to the log file. If empty, write to stdout. If "syslog", writes to syslog pidFile string // File name to save PID to @@ -54,10 +55,19 @@ type arg struct { // against its zero value and return nil if the parameter value is // zero otherwise they return a string slice of the parameter +func ipSliceOrNil(ip net.IP) []string { + if ip == nil { + return nil + } + + return []string{ip.String()} +} + func stringSliceOrNil(s string) []string { if s == "" { return nil } + return []string{s} } @@ -65,6 +75,7 @@ func intSliceOrNil(i int) []string { if i == 0 { return nil } + return []string{strconv.Itoa(i)} } @@ -72,6 +83,7 @@ func boolSliceOrNil(b bool) []string { if b { return []string{} } + return nil } @@ -96,8 +108,8 @@ var workDirArg = arg{ var hostArg = arg{ "Host address to bind HTTP server on", "host", "h", - func(o options, v string) (options, error) { o.bindHost = v; return o, nil }, nil, nil, - func(o options) []string { return stringSliceOrNil(o.bindHost) }, + func(o options, v string) (options, error) { o.bindHost = net.ParseIP(v); return o, nil }, nil, nil, + func(o options) []string { return ipSliceOrNil(o.bindHost) }, } var portArg = arg{ diff --git a/internal/home/options_test.go b/internal/home/options_test.go index afaa873f24d..f24dc816f48 100644 --- a/internal/home/options_test.go +++ b/internal/home/options_test.go @@ -2,6 +2,7 @@ package home import ( "fmt" + "net" "testing" ) @@ -65,14 +66,14 @@ func TestParseWorkDir(t *testing.T) { } func TestParseBindHost(t *testing.T) { - if testParseOk(t).bindHost != "" { + if testParseOk(t).bindHost != nil { t.Fatal("empty is no host") } - if testParseOk(t, "-h", "addr").bindHost != "addr" { + if !testParseOk(t, "-h", "1.2.3.4").bindHost.Equal(net.IP{1, 2, 3, 4}) { t.Fatal("-h is host") } testParseParamMissing(t, "-h") - if testParseOk(t, "--host", "addr").bindHost != "addr" { + if !testParseOk(t, "--host", "1.2.3.4").bindHost.Equal(net.IP{1, 2, 3, 4}) { t.Fatal("--host is host") } testParseParamMissing(t, "--host") @@ -204,7 +205,7 @@ func TestSerializeWorkDir(t *testing.T) { } func TestSerializeBindHost(t *testing.T) { - testSerialize(t, options{bindHost: "addr"}, "-h", "addr") + testSerialize(t, options{bindHost: net.IP{1, 2, 3, 4}}, "-h", "1.2.3.4") } func TestSerializeBindPort(t *testing.T) { diff --git a/internal/home/rdns.go b/internal/home/rdns.go index 05df66ef591..3955ecb54d0 100644 --- a/internal/home/rdns.go +++ b/internal/home/rdns.go @@ -2,6 +2,7 @@ package home import ( "encoding/binary" + "net" "strings" "time" @@ -15,7 +16,7 @@ import ( type RDNS struct { dnsServer *dnsforward.Server clients *clientsContainer - ipChannel chan string // pass data from DNS request handling thread to rDNS thread + ipChannel chan net.IP // pass data from DNS request handling thread to rDNS thread // Contains IP addresses of clients to be resolved by rDNS // If IP address is resolved, it stays here while it's inside Clients. @@ -35,13 +36,13 @@ func InitRDNS(dnsServer *dnsforward.Server, clients *clientsContainer) *RDNS { cconf.MaxCount = 10000 r.ipAddrs = cache.New(cconf) - r.ipChannel = make(chan string, 256) + r.ipChannel = make(chan net.IP, 256) go r.workerLoop() return &r } // Begin - add IP address to rDNS queue -func (r *RDNS) Begin(ip string) { +func (r *RDNS) Begin(ip net.IP) { now := uint64(time.Now().Unix()) expire := r.ipAddrs.Get([]byte(ip)) if len(expire) != 0 { @@ -70,7 +71,7 @@ func (r *RDNS) Begin(ip string) { } // Use rDNS to get hostname by IP address -func (r *RDNS) resolve(ip string) string { +func (r *RDNS) resolve(ip net.IP) string { log.Tracef("Resolving host for %s", ip) req := dns.Msg{} @@ -83,7 +84,7 @@ func (r *RDNS) resolve(ip string) string { }, } var err error - req.Question[0].Name, err = dns.ReverseAddr(ip) + req.Question[0].Name, err = dns.ReverseAddr(ip.String()) if err != nil { log.Debug("Error while calling dns.ReverseAddr(%s): %s", ip, err) return "" @@ -123,6 +124,6 @@ func (r *RDNS) workerLoop() { continue } - _, _ = r.clients.AddHost(ip, host, ClientSourceRDNS) + _, _ = r.clients.AddHost(ip.String(), host, ClientSourceRDNS) } } diff --git a/internal/home/rdns_test.go b/internal/home/rdns_test.go index 516b0ed5d5c..53dd093dbe9 100644 --- a/internal/home/rdns_test.go +++ b/internal/home/rdns_test.go @@ -1,6 +1,7 @@ package home import ( + "net" "testing" "github.com/AdguardTeam/AdGuardHome/internal/dnsforward" @@ -16,6 +17,6 @@ func TestResolveRDNS(t *testing.T) { clients := &clientsContainer{} rdns := InitRDNS(dns, clients) - r := rdns.resolve("1.1.1.1") + r := rdns.resolve(net.IP{1, 1, 1, 1}) assert.Equal(t, "one.one.one.one", r, r) } diff --git a/internal/home/web.go b/internal/home/web.go index 83fe9db4be5..541f87d001d 100644 --- a/internal/home/web.go +++ b/internal/home/web.go @@ -31,7 +31,7 @@ const ( type webConfig struct { firstRun bool - BindHost string + BindHost net.IP BindPort int BetaBindPort int PortHTTPS int @@ -114,7 +114,7 @@ func WebCheckPortAvailable(port int) bool { alreadyRunning = true } if !alreadyRunning { - err := util.CheckPortAvailable(config.BindHost, port) + err := util.CheckPortAvailable(config.BindHost.String(), port) if err != nil { return false } @@ -164,7 +164,7 @@ func (web *Web) Start() { // we need to have new instance, because after Shutdown() the Server is not usable web.httpServer = &http.Server{ ErrorLog: log.StdLog("web: http", log.DEBUG), - Addr: net.JoinHostPort(web.conf.BindHost, strconv.Itoa(web.conf.BindPort)), + Addr: net.JoinHostPort(web.conf.BindHost.String(), strconv.Itoa(web.conf.BindPort)), Handler: withMiddlewares(Context.mux, limitRequestBody), ReadTimeout: web.conf.ReadTimeout, ReadHeaderTimeout: web.conf.ReadHeaderTimeout, @@ -177,7 +177,7 @@ func (web *Web) Start() { if web.conf.BetaBindPort != 0 { web.httpServerBeta = &http.Server{ ErrorLog: log.StdLog("web: http", log.DEBUG), - Addr: net.JoinHostPort(web.conf.BindHost, strconv.Itoa(web.conf.BetaBindPort)), + Addr: net.JoinHostPort(web.conf.BindHost.String(), strconv.Itoa(web.conf.BetaBindPort)), Handler: withMiddlewares(Context.mux, limitRequestBody, web.wrapIndexBeta), ReadTimeout: web.conf.ReadTimeout, ReadHeaderTimeout: web.conf.ReadHeaderTimeout, @@ -236,7 +236,7 @@ func (web *Web) tlsServerLoop() { web.httpsServer.cond.L.Unlock() // prepare HTTPS server - address := net.JoinHostPort(web.conf.BindHost, strconv.Itoa(web.conf.PortHTTPS)) + address := net.JoinHostPort(web.conf.BindHost.String(), strconv.Itoa(web.conf.PortHTTPS)) web.httpsServer.server = &http.Server{ ErrorLog: log.StdLog("web: https", log.DEBUG), Addr: address, diff --git a/internal/home/whois.go b/internal/home/whois.go index 4884d776dc0..6c40ed54e98 100644 --- a/internal/home/whois.go +++ b/internal/home/whois.go @@ -26,7 +26,7 @@ const ( // Whois - module context type Whois struct { clients *clientsContainer - ipChan chan string + ipChan chan net.IP timeoutMsec uint // Contains IP addresses of clients @@ -46,7 +46,7 @@ func initWhois(clients *clientsContainer) *Whois { cconf.MaxCount = 10000 w.ipAddrs = cache.New(cconf) - w.ipChan = make(chan string, 255) + w.ipChan = make(chan net.IP, 255) go w.workerLoop() return &w } @@ -183,9 +183,9 @@ func (w *Whois) queryAll(target string) (string, error) { } // Request WHOIS information -func (w *Whois) process(ip string) [][]string { +func (w *Whois) process(ip net.IP) [][]string { data := [][]string{} - resp, err := w.queryAll(ip) + resp, err := w.queryAll(ip.String()) if err != nil { log.Debug("Whois: error: %s IP:%s", err, ip) return data @@ -209,7 +209,7 @@ func (w *Whois) process(ip string) [][]string { } // Begin - begin requesting WHOIS info -func (w *Whois) Begin(ip string) { +func (w *Whois) Begin(ip net.IP) { now := uint64(time.Now().Unix()) expire := w.ipAddrs.Get([]byte(ip)) if len(expire) != 0 { diff --git a/internal/sysutil/net_linux.go b/internal/sysutil/net_linux.go index 06d27eb2845..8f47cf428fc 100644 --- a/internal/sysutil/net_linux.go +++ b/internal/sysutil/net_linux.go @@ -119,17 +119,13 @@ func ifacesStaticConfig(r io.Reader, ifaceName string) (has bool, err error) { } func ifaceSetStaticIP(ifaceName string) (err error) { - ip := util.GetSubnet(ifaceName) - if len(ip) == 0 { + ipNet := util.GetSubnet(ifaceName) + if ipNet.IP == nil { return errors.New("can't get IP address") } - ip4, _, err := net.ParseCIDR(ip) - if err != nil { - return err - } gatewayIP := GatewayIP(ifaceName) - add := updateStaticIPdhcpcdConf(ifaceName, ip, gatewayIP, ip4) + add := updateStaticIPdhcpcdConf(ifaceName, ipNet.String(), gatewayIP, ipNet.IP) body, err := ioutil.ReadFile("/etc/dhcpcd.conf") if err != nil { diff --git a/internal/util/network.go b/internal/util/network.go index 1731ed08c93..0cebb121d12 100644 --- a/internal/util/network.go +++ b/internal/util/network.go @@ -15,12 +15,12 @@ import ( // NetInterface represents a list of network interfaces type NetInterface struct { - Name string // Network interface name - MTU int // MTU - HardwareAddr string // Hardware address - Addresses []string // Array with the network interface addresses - Subnets []string // Array with CIDR addresses of this network interface - Flags string // Network interface flags (up, broadcast, etc) + Name string // Network interface name + MTU int // MTU + HardwareAddr string // Hardware address + Addresses []net.IP // Array with the network interface addresses + Subnets []net.IPNet // Array with CIDR addresses of this network interface + Flags string // Network interface flags (up, broadcast, etc) } // GetValidNetInterfaces returns interfaces that are eligible for DNS and/or DHCP @@ -78,8 +78,8 @@ func GetValidNetInterfacesForWeb() ([]NetInterface, error) { if ipNet.IP.IsLinkLocalUnicast() { continue } - netIface.Addresses = append(netIface.Addresses, ipNet.IP.String()) - netIface.Subnets = append(netIface.Subnets, ipNet.String()) + netIface.Addresses = append(netIface.Addresses, ipNet.IP) + netIface.Subnets = append(netIface.Subnets, *ipNet) } // Discard interfaces with no addresses @@ -91,8 +91,8 @@ func GetValidNetInterfacesForWeb() ([]NetInterface, error) { return netInterfaces, nil } -// GetInterfaceByIP - Get interface name by its IP address. -func GetInterfaceByIP(ip string) string { +// GetInterfaceByIP returns the name of interface containing provided ip. +func GetInterfaceByIP(ip net.IP) string { ifaces, err := GetValidNetInterfacesForWeb() if err != nil { return "" @@ -100,7 +100,7 @@ func GetInterfaceByIP(ip string) string { for _, iface := range ifaces { for _, addr := range iface.Addresses { - if ip == addr { + if ip.Equal(addr) { return iface.Name } } @@ -111,11 +111,11 @@ func GetInterfaceByIP(ip string) string { // GetSubnet - Get IP address with netmask for the specified interface // Returns an empty string if it fails to find it -func GetSubnet(ifaceName string) string { +func GetSubnet(ifaceName string) net.IPNet { netIfaces, err := GetValidNetInterfacesForWeb() if err != nil { log.Error("Could not get network interfaces info: %v", err) - return "" + return net.IPNet{} } for _, netIface := range netIfaces { @@ -124,7 +124,7 @@ func GetSubnet(ifaceName string) string { } } - return "" + return net.IPNet{} } // CheckPortAvailable - check if TCP port is available