diff --git a/internal/dhcpd/dhcpd_test.go b/internal/dhcpd/dhcpd_test.go index 58b88bda68f..1d3aab21058 100644 --- a/internal/dhcpd/dhcpd_test.go +++ b/internal/dhcpd/dhcpd_test.go @@ -28,27 +28,27 @@ func TestDB(t *testing.T) { conf := V4ServerConf{ Enabled: true, - RangeStart: "192.168.10.100", - RangeEnd: "192.168.10.200", - GatewayIP: "192.168.10.1", - SubnetMask: "255.255.255.0", + RangeStart: net.IP{192, 168, 10, 100}, + RangeEnd: net.IP{192, 168, 10, 200}, + GatewayIP: net.IP{192, 168, 10, 1}, + SubnetMask: net.IP{255, 255, 255, 0}, notify: testNotify, } s.srv4, err = v4Create(conf) - assert.True(t, err == nil) + assert.Nil(t, err) s.srv6, err = v6Create(V6ServerConf{}) - assert.True(t, err == nil) + assert.Nil(t, err) l := Lease{} - l.IP = net.ParseIP("192.168.10.100").To4() + l.IP = net.IP{192, 168, 10, 100} l.HWAddr, _ = net.ParseMAC("aa:aa:aa:aa:aa:aa") exp1 := time.Now().Add(time.Hour) l.Expiry = exp1 s.srv4.(*v4Server).addLease(&l) l2 := Lease{} - l2.IP = net.ParseIP("192.168.10.101").To4() + l2.IP = net.IP{192, 168, 10, 101} l2.HWAddr, _ = net.ParseMAC("aa:aa:aa:aa:aa:bb") s.srv4.AddStaticLease(l2) @@ -62,7 +62,7 @@ func TestDB(t *testing.T) { assert.Equal(t, "aa:aa:aa:aa:aa:bb", ll[0].HWAddr.String()) assert.Equal(t, "192.168.10.101", ll[0].IP.String()) - assert.Equal(t, int64(leaseExpireStatic), ll[0].Expiry.Unix()) + assert.EqualValues(t, leaseExpireStatic, ll[0].Expiry.Unix()) assert.Equal(t, "aa:aa:aa:aa:aa:aa", ll[1].HWAddr.String()) assert.Equal(t, "192.168.10.100", ll[1].IP.String()) @@ -75,8 +75,8 @@ func TestIsValidSubnetMask(t *testing.T) { assert.True(t, isValidSubnetMask([]byte{255, 255, 255, 0})) assert.True(t, isValidSubnetMask([]byte{255, 255, 254, 0})) assert.True(t, isValidSubnetMask([]byte{255, 255, 252, 0})) - assert.True(t, !isValidSubnetMask([]byte{255, 255, 253, 0})) - assert.True(t, !isValidSubnetMask([]byte{255, 255, 255, 1})) + assert.False(t, isValidSubnetMask([]byte{255, 255, 253, 0})) + assert.False(t, isValidSubnetMask([]byte{255, 255, 255, 1})) } func TestNormalizeLeases(t *testing.T) { @@ -100,7 +100,7 @@ func TestNormalizeLeases(t *testing.T) { leases := normalizeLeases(staticLeases, dynLeases) - assert.True(t, len(leases) == 3) + assert.Len(t, leases, 3) assert.True(t, bytes.Equal(leases[0].HWAddr, []byte{1, 2, 3, 4})) assert.True(t, bytes.Equal(leases[0].IP, []byte{0, 2, 3, 4})) assert.True(t, bytes.Equal(leases[1].HWAddr, []byte{2, 2, 3, 4})) @@ -109,22 +109,22 @@ func TestNormalizeLeases(t *testing.T) { func TestOptions(t *testing.T) { code, val := parseOptionString(" 12 hex abcdef ") - assert.Equal(t, uint8(12), code) + assert.EqualValues(t, 12, code) assert.True(t, bytes.Equal([]byte{0xab, 0xcd, 0xef}, val)) code, _ = parseOptionString(" 12 hex abcdef1 ") - assert.Equal(t, uint8(0), code) + assert.EqualValues(t, 0, code) code, val = parseOptionString("123 ip 1.2.3.4") - assert.Equal(t, uint8(123), code) + assert.EqualValues(t, 123, code) assert.Equal(t, "1.2.3.4", net.IP(string(val)).String()) code, _ = parseOptionString("256 ip 1.1.1.1") - assert.Equal(t, uint8(0), code) + assert.EqualValues(t, 0, code) code, _ = parseOptionString("-1 ip 1.1.1.1") - assert.Equal(t, uint8(0), code) + assert.EqualValues(t, 0, code) code, _ = parseOptionString("12 ip 1.1.1.1x") - assert.Equal(t, uint8(0), code) + assert.EqualValues(t, 0, code) code, _ = parseOptionString("12 x 1.1.1.1") - assert.Equal(t, uint8(0), code) + assert.EqualValues(t, 0, code) } diff --git a/internal/dhcpd/dhcphttp.go b/internal/dhcpd/dhcphttp.go index 06a4cc79efb..e35322f889d 100644 --- a/internal/dhcpd/dhcphttp.go +++ b/internal/dhcpd/dhcphttp.go @@ -42,10 +42,10 @@ func convertLeases(inputLeases []Lease, includeExpires bool) []map[string]string } type v4ServerConfJSON struct { - GatewayIP string `json:"gateway_ip"` - SubnetMask string `json:"subnet_mask"` - RangeStart string `json:"range_start"` - RangeEnd string `json:"range_end"` + GatewayIP net.IP `json:"gateway_ip"` + SubnetMask net.IP `json:"subnet_mask"` + RangeStart net.IP `json:"range_start"` + RangeEnd net.IP `json:"range_end"` LeaseDuration uint32 `json:"lease_duration"` } @@ -61,10 +61,10 @@ func v4ServerConfToJSON(c V4ServerConf) v4ServerConfJSON { func v4JSONToServerConf(j v4ServerConfJSON) V4ServerConf { return V4ServerConf{ - GatewayIP: j.GatewayIP, - SubnetMask: j.SubnetMask, - RangeStart: j.RangeStart, - RangeEnd: j.RangeEnd, + GatewayIP: j.GatewayIP.To4(), + SubnetMask: j.SubnetMask.To4(), + RangeStart: j.RangeStart.To4(), + RangeEnd: j.RangeEnd.To4(), LeaseDuration: j.LeaseDuration, } } @@ -117,7 +117,7 @@ func (s *Server) handleDHCPStatus(w http.ResponseWriter, r *http.Request) { type staticLeaseJSON struct { HWAddr string `json:"mac"` - IP string `json:"ip"` + IP net.IP `json:"ip"` Hostname string `json:"hostname"` } @@ -225,10 +225,10 @@ func (s *Server) handleDHCPSetConfig(w http.ResponseWriter, r *http.Request) { type netInterfaceJSON struct { Name string `json:"name"` - GatewayIP string `json:"gateway_ip"` + GatewayIP net.IP `json:"gateway_ip"` HardwareAddr string `json:"hardware_address"` - Addrs4 []string `json:"ipv4_addresses"` - Addrs6 []string `json:"ipv6_addresses"` + Addrs4 []net.IP `json:"ipv4_addresses"` + Addrs6 []net.IP `json:"ipv6_addresses"` Flags string `json:"flags"` } @@ -277,9 +277,9 @@ func (s *Server) handleDHCPInterfaces(w http.ResponseWriter, r *http.Request) { continue } if ipnet.IP.To4() != nil { - jsonIface.Addrs4 = append(jsonIface.Addrs4, ipnet.IP.String()) + jsonIface.Addrs4 = append(jsonIface.Addrs4, ipnet.IP) } else { - jsonIface.Addrs6 = append(jsonIface.Addrs6, ipnet.IP.String()) + jsonIface.Addrs6 = append(jsonIface.Addrs6, ipnet.IP) } } if len(jsonIface.Addrs4)+len(jsonIface.Addrs6) != 0 { @@ -375,50 +375,46 @@ func (s *Server) handleDHCPAddStaticLease(w http.ResponseWriter, r *http.Request err := json.NewDecoder(r.Body).Decode(&lj) if err != nil { httpError(r, w, http.StatusBadRequest, "json.Decode: %s", err) + return } - ip := net.ParseIP(lj.IP) - if ip != nil && ip.To4() == nil { - mac, err := net.ParseMAC(lj.HWAddr) + if lj.IP == nil { + httpError(r, w, http.StatusBadRequest, "invalid IP") + + return + } + + ip4 := lj.IP.To4() + + mac, err := net.ParseMAC(lj.HWAddr) + lease := Lease{ + HWAddr: mac, + } + + if ip4 == nil { + lease.IP = lj.IP.To16() + if err != nil { httpError(r, w, http.StatusBadRequest, "invalid MAC") - return - } - lease := Lease{ - IP: ip, - HWAddr: mac, + return } err = s.srv6.AddStaticLease(lease) if err != nil { httpError(r, w, http.StatusBadRequest, "%s", err) - return } - return - } - ip, _ = parseIPv4(lj.IP) - if ip == nil { - httpError(r, w, http.StatusBadRequest, "invalid IP") - return - } - - mac, err := net.ParseMAC(lj.HWAddr) - if err != nil { - httpError(r, w, http.StatusBadRequest, "invalid MAC") return } - lease := Lease{ - IP: ip, - HWAddr: mac, - Hostname: lj.Hostname, - } + lease.IP = ip4 + lease.Hostname = lj.Hostname err = s.srv4.AddStaticLease(lease) if err != nil { httpError(r, w, http.StatusBadRequest, "%s", err) + return } } @@ -428,46 +424,46 @@ func (s *Server) handleDHCPRemoveStaticLease(w http.ResponseWriter, r *http.Requ err := json.NewDecoder(r.Body).Decode(&lj) if err != nil { httpError(r, w, http.StatusBadRequest, "json.Decode: %s", err) + return } - ip := net.ParseIP(lj.IP) - if ip != nil && ip.To4() == nil { - mac, err := net.ParseMAC(lj.HWAddr) + if lj.IP == nil { + httpError(r, w, http.StatusBadRequest, "invalid IP") + + return + } + + ip4 := lj.IP.To4() + + mac, err := net.ParseMAC(lj.HWAddr) + lease := Lease{ + HWAddr: mac, + } + + if ip4 == nil { + lease.IP = lj.IP.To16() + if err != nil { httpError(r, w, http.StatusBadRequest, "invalid MAC") - return - } - lease := Lease{ - IP: ip, - HWAddr: mac, + return } err = s.srv6.RemoveStaticLease(lease) if err != nil { httpError(r, w, http.StatusBadRequest, "%s", err) - return } - return - } - ip, _ = parseIPv4(lj.IP) - if ip == nil { - httpError(r, w, http.StatusBadRequest, "invalid IP") return } - mac, _ := net.ParseMAC(lj.HWAddr) - - lease := Lease{ - IP: ip, - HWAddr: mac, - Hostname: lj.Hostname, - } + lease.IP = ip4 + lease.Hostname = lj.Hostname err = s.srv4.RemoveStaticLease(lease) if err != nil { httpError(r, w, http.StatusBadRequest, "%s", err) + return } } diff --git a/internal/dhcpd/helpers.go b/internal/dhcpd/helpers.go index aafda988b56..28856b5f3b8 100644 --- a/internal/dhcpd/helpers.go +++ b/internal/dhcpd/helpers.go @@ -14,15 +14,17 @@ func isTimeout(err error) bool { return operr.Timeout() } -func parseIPv4(text string) (net.IP, error) { - result := net.ParseIP(text) - if result == nil { - return nil, fmt.Errorf("%s is not an IP address", text) +func tryTo4(ip net.IP) (ip4 net.IP, err error) { + if ip == nil { + return nil, fmt.Errorf("%v is not an IP address", ip) } - if result.To4() == nil { - return nil, fmt.Errorf("%s is not an IPv4 address", text) + + ip4 = ip.To4() + if ip4 == nil { + return nil, fmt.Errorf("%v is not an IPv4 address", ip) } - return result.To4(), nil + + return ip4, nil } // Return TRUE if subnet mask is correct (e.g. 255.255.255.0) diff --git a/internal/dhcpd/server.go b/internal/dhcpd/server.go index 240715ca375..20f6cad3154 100644 --- a/internal/dhcpd/server.go +++ b/internal/dhcpd/server.go @@ -36,13 +36,13 @@ type V4ServerConf struct { Enabled bool `yaml:"-"` InterfaceName string `yaml:"-"` - GatewayIP string `yaml:"gateway_ip"` - SubnetMask string `yaml:"subnet_mask"` + GatewayIP net.IP `yaml:"gateway_ip"` + SubnetMask net.IP `yaml:"subnet_mask"` // The first & the last IP address for dynamic leases // Bytes [0..2] of the last allowed IP address must match the first IP - RangeStart string `yaml:"range_start"` - RangeEnd string `yaml:"range_end"` + RangeStart net.IP `yaml:"range_start"` + RangeEnd net.IP `yaml:"range_end"` LeaseDuration uint32 `yaml:"lease_duration"` // in seconds diff --git a/internal/dhcpd/v4.go b/internal/dhcpd/v4.go index 81ba3a1d7fb..038d8f9c06f 100644 --- a/internal/dhcpd/v4.go +++ b/internal/dhcpd/v4.go @@ -589,7 +589,7 @@ func (s *v4Server) Start() error { s.conf.dnsIPAddrs = dnsIPAddrs laddr := &net.UDPAddr{ - IP: net.ParseIP("0.0.0.0"), + IP: net.IP{0, 0, 0, 0}, Port: dhcpv4.ServerPort, } s.srv, err = server4.NewServer(iface.Name, laddr, s.packetHandler, server4.WithDebugLogger()) @@ -632,19 +632,18 @@ func v4Create(conf V4ServerConf) (DHCPServer, error) { } var err error - s.conf.routerIP, err = parseIPv4(s.conf.GatewayIP) + s.conf.routerIP, err = tryTo4(s.conf.GatewayIP) if err != nil { return s, fmt.Errorf("dhcpv4: %w", err) } - subnet, err := parseIPv4(s.conf.SubnetMask) - if err != nil || !isValidSubnetMask(subnet) { - return s, fmt.Errorf("dhcpv4: invalid subnet mask: %s", s.conf.SubnetMask) + if s.conf.SubnetMask == nil { + return s, fmt.Errorf("dhcpv4: invalid subnet mask: %v", s.conf.SubnetMask) } s.conf.subnetMask = make([]byte, 4) - copy(s.conf.subnetMask, subnet) + copy(s.conf.subnetMask, s.conf.SubnetMask.To4()) - s.conf.ipStart, err = parseIPv4(conf.RangeStart) + s.conf.ipStart, err = tryTo4(conf.RangeStart) if s.conf.ipStart == nil { return s, fmt.Errorf("dhcpv4: %w", err) } @@ -652,7 +651,7 @@ func v4Create(conf V4ServerConf) (DHCPServer, error) { return s, fmt.Errorf("dhcpv4: invalid range start IP") } - s.conf.ipEnd, err = parseIPv4(conf.RangeEnd) + s.conf.ipEnd, err = tryTo4(conf.RangeEnd) if s.conf.ipEnd == nil { return s, fmt.Errorf("dhcpv4: %w", err) } diff --git a/internal/dhcpd/v4_test.go b/internal/dhcpd/v4_test.go index fe3ac2ddf48..e3086026a5f 100644 --- a/internal/dhcpd/v4_test.go +++ b/internal/dhcpd/v4_test.go @@ -16,119 +16,119 @@ func notify4(flags uint32) { func TestV4StaticLeaseAddRemove(t *testing.T) { conf := V4ServerConf{ Enabled: true, - RangeStart: "192.168.10.100", - RangeEnd: "192.168.10.200", - GatewayIP: "192.168.10.1", - SubnetMask: "255.255.255.0", + RangeStart: net.IP{192, 168, 10, 100}, + RangeEnd: net.IP{192, 168, 10, 200}, + GatewayIP: net.IP{192, 168, 10, 1}, + SubnetMask: net.IP{255, 255, 255, 0}, notify: notify4, } s, err := v4Create(conf) - assert.True(t, err == nil) + assert.Nil(t, err) ls := s.GetLeases(LeasesStatic) - assert.Equal(t, 0, len(ls)) + assert.Empty(t, ls) // add static lease l := Lease{} - l.IP = net.ParseIP("192.168.10.150").To4() + l.IP = net.IP{192, 168, 10, 150} l.HWAddr, _ = net.ParseMAC("aa:aa:aa:aa:aa:aa") - assert.True(t, s.AddStaticLease(l) == nil) + assert.Nil(t, s.AddStaticLease(l)) // try to add the same static lease - fail - assert.True(t, s.AddStaticLease(l) != nil) + assert.NotNil(t, s.AddStaticLease(l)) // check ls = s.GetLeases(LeasesStatic) - assert.Equal(t, 1, len(ls)) + assert.Len(t, ls, 1) assert.Equal(t, "192.168.10.150", ls[0].IP.String()) assert.Equal(t, "aa:aa:aa:aa:aa:aa", ls[0].HWAddr.String()) - assert.True(t, ls[0].Expiry.Unix() == leaseExpireStatic) + assert.EqualValues(t, leaseExpireStatic, ls[0].Expiry.Unix()) // try to remove static lease - fail - l.IP = net.ParseIP("192.168.10.110").To4() + l.IP = net.IP{192, 168, 10, 110} l.HWAddr, _ = net.ParseMAC("aa:aa:aa:aa:aa:aa") - assert.True(t, s.RemoveStaticLease(l) != nil) + assert.NotNil(t, s.RemoveStaticLease(l)) // remove static lease - l.IP = net.ParseIP("192.168.10.150").To4() + l.IP = net.IP{192, 168, 10, 150} l.HWAddr, _ = net.ParseMAC("aa:aa:aa:aa:aa:aa") - assert.True(t, s.RemoveStaticLease(l) == nil) + assert.Nil(t, s.RemoveStaticLease(l)) // check ls = s.GetLeases(LeasesStatic) - assert.Equal(t, 0, len(ls)) + assert.Empty(t, ls) } func TestV4StaticLeaseAddReplaceDynamic(t *testing.T) { conf := V4ServerConf{ Enabled: true, - RangeStart: "192.168.10.100", - RangeEnd: "192.168.10.200", - GatewayIP: "192.168.10.1", - SubnetMask: "255.255.255.0", + RangeStart: net.IP{192, 168, 10, 100}, + RangeEnd: net.IP{192, 168, 10, 200}, + GatewayIP: net.IP{192, 168, 10, 1}, + SubnetMask: net.IP{255, 255, 255, 0}, notify: notify4, } sIface, err := v4Create(conf) s := sIface.(*v4Server) - assert.True(t, err == nil) + assert.Nil(t, err) // add dynamic lease ld := Lease{} - ld.IP = net.ParseIP("192.168.10.150").To4() + ld.IP = net.IP{192, 168, 10, 150} ld.HWAddr, _ = net.ParseMAC("11:aa:aa:aa:aa:aa") s.addLease(&ld) // add dynamic lease { ld := Lease{} - ld.IP = net.ParseIP("192.168.10.151").To4() + ld.IP = net.IP{192, 168, 10, 151} ld.HWAddr, _ = net.ParseMAC("22:aa:aa:aa:aa:aa") s.addLease(&ld) } // add static lease with the same IP l := Lease{} - l.IP = net.ParseIP("192.168.10.150").To4() + l.IP = net.IP{192, 168, 10, 150} l.HWAddr, _ = net.ParseMAC("33:aa:aa:aa:aa:aa") - assert.True(t, s.AddStaticLease(l) == nil) + assert.Nil(t, s.AddStaticLease(l)) // add static lease with the same MAC l = Lease{} - l.IP = net.ParseIP("192.168.10.152").To4() + l.IP = net.IP{192, 168, 10, 152} l.HWAddr, _ = net.ParseMAC("22:aa:aa:aa:aa:aa") - assert.True(t, s.AddStaticLease(l) == nil) + assert.Nil(t, s.AddStaticLease(l)) // check ls := s.GetLeases(LeasesStatic) - assert.Equal(t, 2, len(ls)) + assert.Len(t, ls, 2) assert.Equal(t, "192.168.10.150", ls[0].IP.String()) assert.Equal(t, "33:aa:aa:aa:aa:aa", ls[0].HWAddr.String()) - assert.True(t, ls[0].Expiry.Unix() == leaseExpireStatic) + assert.EqualValues(t, leaseExpireStatic, ls[0].Expiry.Unix()) assert.Equal(t, "192.168.10.152", ls[1].IP.String()) assert.Equal(t, "22:aa:aa:aa:aa:aa", ls[1].HWAddr.String()) - assert.True(t, ls[1].Expiry.Unix() == leaseExpireStatic) + assert.EqualValues(t, leaseExpireStatic, ls[1].Expiry.Unix()) } func TestV4StaticLeaseGet(t *testing.T) { conf := V4ServerConf{ Enabled: true, - RangeStart: "192.168.10.100", - RangeEnd: "192.168.10.200", - GatewayIP: "192.168.10.1", - SubnetMask: "255.255.255.0", + RangeStart: net.IP{192, 168, 10, 100}, + RangeEnd: net.IP{192, 168, 10, 200}, + GatewayIP: net.IP{192, 168, 10, 1}, + SubnetMask: net.IP{255, 255, 255, 0}, notify: notify4, } sIface, err := v4Create(conf) s := sIface.(*v4Server) - assert.True(t, err == nil) - s.conf.dnsIPAddrs = []net.IP{net.ParseIP("192.168.10.1").To4()} + assert.Nil(t, err) + s.conf.dnsIPAddrs = []net.IP{{192, 168, 10, 1}} l := Lease{} - l.IP = net.ParseIP("192.168.10.150").To4() + l.IP = net.IP{192, 168, 10, 150} l.HWAddr, _ = net.ParseMAC("aa:aa:aa:aa:aa:aa") - assert.True(t, s.AddStaticLease(l) == nil) + assert.Nil(t, s.AddStaticLease(l)) // "Discover" mac, _ := net.ParseMAC("aa:aa:aa:aa:aa:aa") @@ -160,12 +160,12 @@ func TestV4StaticLeaseGet(t *testing.T) { assert.Equal(t, s.conf.leaseTime.Seconds(), resp.IPAddressLeaseTime(-1).Seconds()) dnsAddrs := resp.DNS() - assert.Equal(t, 1, len(dnsAddrs)) + assert.Len(t, dnsAddrs, 1) assert.Equal(t, "192.168.10.1", dnsAddrs[0].String()) // check lease ls := s.GetLeases(LeasesStatic) - assert.Equal(t, 1, len(ls)) + assert.Len(t, ls, 1) assert.Equal(t, "192.168.10.150", ls[0].IP.String()) assert.Equal(t, "aa:aa:aa:aa:aa:aa", ls[0].HWAddr.String()) } @@ -173,10 +173,10 @@ func TestV4StaticLeaseGet(t *testing.T) { func TestV4DynamicLeaseGet(t *testing.T) { conf := V4ServerConf{ Enabled: true, - RangeStart: "192.168.10.100", - RangeEnd: "192.168.10.200", - GatewayIP: "192.168.10.1", - SubnetMask: "255.255.255.0", + RangeStart: net.IP{192, 168, 10, 100}, + RangeEnd: net.IP{192, 168, 10, 200}, + GatewayIP: net.IP{192, 168, 10, 1}, + SubnetMask: net.IP{255, 255, 255, 0}, notify: notify4, Options: []string{ "81 hex 303132", @@ -185,8 +185,8 @@ func TestV4DynamicLeaseGet(t *testing.T) { } sIface, err := v4Create(conf) s := sIface.(*v4Server) - assert.True(t, err == nil) - s.conf.dnsIPAddrs = []net.IP{net.ParseIP("192.168.10.1").To4()} + assert.Nil(t, err) + s.conf.dnsIPAddrs = []net.IP{{192, 168, 10, 1}} // "Discover" mac, _ := net.ParseMAC("aa:aa:aa:aa:aa:aa") @@ -220,19 +220,19 @@ func TestV4DynamicLeaseGet(t *testing.T) { assert.Equal(t, s.conf.leaseTime.Seconds(), resp.IPAddressLeaseTime(-1).Seconds()) dnsAddrs := resp.DNS() - assert.Equal(t, 1, len(dnsAddrs)) + assert.Len(t, dnsAddrs, 1) assert.Equal(t, "192.168.10.1", dnsAddrs[0].String()) // check lease ls := s.GetLeases(LeasesDynamic) - assert.Equal(t, 1, len(ls)) + assert.Len(t, ls, 1) assert.Equal(t, "192.168.10.100", ls[0].IP.String()) assert.Equal(t, "aa:aa:aa:aa:aa:aa", ls[0].HWAddr.String()) - start := net.ParseIP("192.168.10.100").To4() - stop := net.ParseIP("192.168.10.200").To4() - assert.True(t, !ip4InRange(start, stop, net.ParseIP("192.168.10.99").To4())) - assert.True(t, !ip4InRange(start, stop, net.ParseIP("192.168.11.100").To4())) - assert.True(t, !ip4InRange(start, stop, net.ParseIP("192.168.11.201").To4())) - assert.True(t, ip4InRange(start, stop, net.ParseIP("192.168.10.100").To4())) + start := net.IP{192, 168, 10, 100} + stop := net.IP{192, 168, 10, 200} + assert.False(t, ip4InRange(start, stop, net.IP{192, 168, 10, 99})) + assert.False(t, ip4InRange(start, stop, net.IP{192, 168, 11, 100})) + assert.False(t, ip4InRange(start, stop, net.IP{192, 168, 11, 201})) + assert.True(t, ip4InRange(start, stop, net.IP{192, 168, 10, 100})) } diff --git a/internal/dhcpd/v6_test.go b/internal/dhcpd/v6_test.go index 7d7dd6787e9..fd3dd89b8a5 100644 --- a/internal/dhcpd/v6_test.go +++ b/internal/dhcpd/v6_test.go @@ -21,40 +21,40 @@ func TestV6StaticLeaseAddRemove(t *testing.T) { notify: notify6, } s, err := v6Create(conf) - assert.True(t, err == nil) + assert.Nil(t, err) ls := s.GetLeases(LeasesStatic) - assert.Equal(t, 0, len(ls)) + assert.Empty(t, ls) // add static lease l := Lease{} l.IP = net.ParseIP("2001::1") l.HWAddr, _ = net.ParseMAC("aa:aa:aa:aa:aa:aa") - assert.True(t, s.AddStaticLease(l) == nil) + assert.Nil(t, s.AddStaticLease(l)) // try to add static lease - fail - assert.True(t, s.AddStaticLease(l) != nil) + assert.NotNil(t, s.AddStaticLease(l)) // check ls = s.GetLeases(LeasesStatic) - assert.Equal(t, 1, len(ls)) + assert.Len(t, ls, 1) assert.Equal(t, "2001::1", ls[0].IP.String()) assert.Equal(t, "aa:aa:aa:aa:aa:aa", ls[0].HWAddr.String()) - assert.True(t, ls[0].Expiry.Unix() == leaseExpireStatic) + assert.EqualValues(t, leaseExpireStatic, ls[0].Expiry.Unix()) // try to remove static lease - fail l.IP = net.ParseIP("2001::2") l.HWAddr, _ = net.ParseMAC("aa:aa:aa:aa:aa:aa") - assert.True(t, s.RemoveStaticLease(l) != nil) + assert.NotNil(t, s.RemoveStaticLease(l)) // remove static lease l.IP = net.ParseIP("2001::1") l.HWAddr, _ = net.ParseMAC("aa:aa:aa:aa:aa:aa") - assert.True(t, s.RemoveStaticLease(l) == nil) + assert.Nil(t, s.RemoveStaticLease(l)) // check ls = s.GetLeases(LeasesStatic) - assert.Equal(t, 0, len(ls)) + assert.Empty(t, ls) } func TestV6StaticLeaseAddReplaceDynamic(t *testing.T) { @@ -65,7 +65,7 @@ func TestV6StaticLeaseAddReplaceDynamic(t *testing.T) { } sIface, err := v6Create(conf) s := sIface.(*v6Server) - assert.True(t, err == nil) + assert.Nil(t, err) // add dynamic lease ld := Lease{} @@ -85,25 +85,25 @@ func TestV6StaticLeaseAddReplaceDynamic(t *testing.T) { l := Lease{} l.IP = net.ParseIP("2001::1") l.HWAddr, _ = net.ParseMAC("33:aa:aa:aa:aa:aa") - assert.True(t, s.AddStaticLease(l) == nil) + assert.Nil(t, s.AddStaticLease(l)) // add static lease with the same MAC l = Lease{} l.IP = net.ParseIP("2001::3") l.HWAddr, _ = net.ParseMAC("22:aa:aa:aa:aa:aa") - assert.True(t, s.AddStaticLease(l) == nil) + assert.Nil(t, s.AddStaticLease(l)) // check ls := s.GetLeases(LeasesStatic) - assert.Equal(t, 2, len(ls)) + assert.Len(t, ls, 2) assert.Equal(t, "2001::1", ls[0].IP.String()) assert.Equal(t, "33:aa:aa:aa:aa:aa", ls[0].HWAddr.String()) - assert.True(t, ls[0].Expiry.Unix() == leaseExpireStatic) + assert.EqualValues(t, leaseExpireStatic, ls[0].Expiry.Unix()) assert.Equal(t, "2001::3", ls[1].IP.String()) assert.Equal(t, "22:aa:aa:aa:aa:aa", ls[1].HWAddr.String()) - assert.True(t, ls[1].Expiry.Unix() == leaseExpireStatic) + assert.EqualValues(t, leaseExpireStatic, ls[1].Expiry.Unix()) } func TestV6GetLease(t *testing.T) { @@ -114,7 +114,7 @@ func TestV6GetLease(t *testing.T) { } sIface, err := v6Create(conf) s := sIface.(*v6Server) - assert.True(t, err == nil) + assert.Nil(t, err) s.conf.dnsIPAddrs = []net.IP{net.ParseIP("2000::1")} s.sid = dhcpv6.Duid{ Type: dhcpv6.DUID_LLT, @@ -125,7 +125,7 @@ func TestV6GetLease(t *testing.T) { l := Lease{} l.IP = net.ParseIP("2001::1") l.HWAddr, _ = net.ParseMAC("aa:aa:aa:aa:aa:aa") - assert.True(t, s.AddStaticLease(l) == nil) + assert.Nil(t, s.AddStaticLease(l)) // "Solicit" mac, _ := net.ParseMAC("aa:aa:aa:aa:aa:aa") @@ -156,12 +156,12 @@ func TestV6GetLease(t *testing.T) { assert.Equal(t, s.conf.leaseTime.Seconds(), oiaAddr.ValidLifetime.Seconds()) dnsAddrs := resp.Options.DNS() - assert.Equal(t, 1, len(dnsAddrs)) + assert.Len(t, dnsAddrs, 1) assert.Equal(t, "2000::1", dnsAddrs[0].String()) // check lease ls := s.GetLeases(LeasesStatic) - assert.Equal(t, 1, len(ls)) + assert.Len(t, ls, 1) assert.Equal(t, "2001::1", ls[0].IP.String()) assert.Equal(t, "aa:aa:aa:aa:aa:aa", ls[0].HWAddr.String()) } @@ -174,7 +174,7 @@ func TestV6GetDynamicLease(t *testing.T) { } sIface, err := v6Create(conf) s := sIface.(*v6Server) - assert.True(t, err == nil) + assert.Nil(t, err) s.conf.dnsIPAddrs = []net.IP{net.ParseIP("2000::1")} s.sid = dhcpv6.Duid{ Type: dhcpv6.DUID_LLT, @@ -209,17 +209,17 @@ func TestV6GetDynamicLease(t *testing.T) { assert.Equal(t, "2001::2", oiaAddr.IPv6Addr.String()) dnsAddrs := resp.Options.DNS() - assert.Equal(t, 1, len(dnsAddrs)) + assert.Len(t, dnsAddrs, 1) assert.Equal(t, "2000::1", dnsAddrs[0].String()) // check lease ls := s.GetLeases(LeasesDynamic) - assert.Equal(t, 1, len(ls)) + assert.Len(t, ls, 1) assert.Equal(t, "2001::2", ls[0].IP.String()) assert.Equal(t, "aa:aa:aa:aa:aa:aa", ls[0].HWAddr.String()) - assert.True(t, !ip6InRange(net.ParseIP("2001::2"), net.ParseIP("2001::1"))) - assert.True(t, !ip6InRange(net.ParseIP("2001::2"), net.ParseIP("2002::2"))) + assert.False(t, ip6InRange(net.ParseIP("2001::2"), net.ParseIP("2001::1"))) + assert.False(t, ip6InRange(net.ParseIP("2001::2"), net.ParseIP("2002::2"))) assert.True(t, ip6InRange(net.ParseIP("2001::2"), net.ParseIP("2001::2"))) assert.True(t, ip6InRange(net.ParseIP("2001::2"), net.ParseIP("2001::3"))) } diff --git a/internal/dnsfilter/dnsfilter_test.go b/internal/dnsfilter/dnsfilter_test.go index 2bae12de07e..abac1fb9125 100644 --- a/internal/dnsfilter/dnsfilter_test.go +++ b/internal/dnsfilter/dnsfilter_test.go @@ -4,7 +4,6 @@ import ( "bytes" "fmt" "net" - "strings" "testing" "github.com/AdguardTeam/AdGuardHome/internal/testutil" @@ -135,7 +134,7 @@ func TestEtcHostsMatching(t *testing.T) { assert.True(t, res.IsFiltered) if assert.Len(t, res.Rules, 1) { assert.Equal(t, "0.0.0.0 block.com", res.Rules[0].Text) - assert.Len(t, res.Rules[0].IP, 0) + assert.Empty(t, res.Rules[0].IP) } // IPv6 @@ -147,7 +146,7 @@ func TestEtcHostsMatching(t *testing.T) { assert.True(t, res.IsFiltered) if assert.Len(t, res.Rules, 1) { assert.Equal(t, "::1 ipv6.com", res.Rules[0].Text) - assert.Len(t, res.Rules[0].IP, 0) + assert.Empty(t, res.Rules[0].IP) } // 2 IPv4 (return only the first one) @@ -180,7 +179,7 @@ func TestSafeBrowsing(t *testing.T) { defer d.Close() d.checkMatch(t, "wmconvirus.narod.ru") - assert.True(t, strings.Contains(logOutput.String(), "SafeBrowsing lookup for wmconvirus.narod.ru")) + assert.Contains(t, logOutput.String(), "SafeBrowsing lookup for wmconvirus.narod.ru") d.checkMatch(t, "test.wmconvirus.narod.ru") d.checkMatchEmpty(t, "yandex.ru") @@ -268,7 +267,7 @@ func TestSafeSearchCacheYandex(t *testing.T) { res, err := d.CheckHost(domain, dns.TypeA, &setts) assert.Nil(t, err) assert.False(t, res.IsFiltered) - assert.Len(t, res.Rules, 0) + assert.Empty(t, res.Rules) d = NewForTest(&Config{SafeSearchEnabled: true}, nil) defer d.Close() @@ -298,7 +297,7 @@ func TestSafeSearchCacheGoogle(t *testing.T) { res, err := d.CheckHost(domain, dns.TypeA, &setts) assert.Nil(t, err) assert.False(t, res.IsFiltered) - assert.Len(t, res.Rules, 0) + assert.Empty(t, res.Rules) d = NewForTest(&Config{SafeSearchEnabled: true}, nil) defer d.Close() @@ -346,7 +345,7 @@ func TestParentalControl(t *testing.T) { d := NewForTest(&Config{ParentalEnabled: true}, nil) defer d.Close() d.checkMatch(t, "pornhub.com") - assert.True(t, strings.Contains(logOutput.String(), "Parental lookup for pornhub.com")) + assert.Contains(t, logOutput.String(), "Parental lookup for pornhub.com") d.checkMatch(t, "www.pornhub.com") d.checkMatchEmpty(t, "www.yandex.ru") d.checkMatchEmpty(t, "yandex.ru") @@ -468,18 +467,20 @@ func TestWhitelist(t *testing.T) { // matched by white filter res, err := d.CheckHost("host1", dns.TypeA, &setts) - assert.True(t, err == nil) - assert.True(t, !res.IsFiltered && res.Reason == NotFilteredAllowList) + assert.Nil(t, err) + assert.False(t, res.IsFiltered) + assert.Equal(t, res.Reason, NotFilteredAllowList) if assert.Len(t, res.Rules, 1) { - assert.True(t, res.Rules[0].Text == "||host1^") + assert.Equal(t, "||host1^", res.Rules[0].Text) } // not matched by white filter, but matched by block filter res, err = d.CheckHost("host2", dns.TypeA, &setts) - assert.True(t, err == nil) - assert.True(t, res.IsFiltered && res.Reason == FilteredBlockList) + assert.Nil(t, err) + assert.True(t, res.IsFiltered) + assert.Equal(t, res.Reason, FilteredBlockList) if assert.Len(t, res.Rules, 1) { - assert.True(t, res.Rules[0].Text == "||host2^") + assert.Equal(t, "||host2^", res.Rules[0].Text) } } @@ -529,7 +530,7 @@ func TestClientSettings(t *testing.T) { // not blocked r, _ = d.CheckHost("facebook.com", dns.TypeA, &setts) - assert.True(t, !r.IsFiltered) + assert.False(t, r.IsFiltered) // override client settings: applyClientSettings(&setts) @@ -554,7 +555,8 @@ func TestClientSettings(t *testing.T) { // blocked by additional rules r, _ = d.CheckHost("facebook.com", dns.TypeA, &setts) - assert.True(t, r.IsFiltered && r.Reason == FilteredBlockedService) + assert.True(t, r.IsFiltered) + assert.Equal(t, r.Reason, FilteredBlockedService) } // BENCHMARKS diff --git a/internal/dnsfilter/dnsrewrite_test.go b/internal/dnsfilter/dnsrewrite_test.go index dadef40688c..201de44f531 100644 --- a/internal/dnsfilter/dnsrewrite_test.go +++ b/internal/dnsfilter/dnsrewrite_test.go @@ -171,7 +171,7 @@ func TestDNSFilter_CheckHostRules_dnsrewrite(t *testing.T) { res, err := f.CheckHostRules(host, dtyp, setts) assert.Nil(t, err) - assert.Equal(t, "", res.CanonName) + assert.Empty(t, res.CanonName) if dnsrr := res.DNSRewriteResult; assert.NotNil(t, dnsrr) { assert.Equal(t, dns.RcodeSuccess, dnsrr.RCode) @@ -197,7 +197,7 @@ func TestDNSFilter_CheckHostRules_dnsrewrite(t *testing.T) { res, err := f.CheckHostRules(host, dtyp, setts) assert.Nil(t, err) - assert.Equal(t, "", res.CanonName) - assert.Len(t, res.Rules, 0) + assert.Empty(t, res.CanonName) + assert.Empty(t, res.Rules) }) } diff --git a/internal/dnsfilter/rewrites_test.go b/internal/dnsfilter/rewrites_test.go index 3a3284ec13d..a38f3f9d902 100644 --- a/internal/dnsfilter/rewrites_test.go +++ b/internal/dnsfilter/rewrites_test.go @@ -27,14 +27,14 @@ func TestRewrites(t *testing.T) { r = d.processRewrites("www.host.com", dns.TypeA) assert.Equal(t, Rewritten, r.Reason) assert.Equal(t, "host.com", r.CanonName) - assert.Equal(t, 2, len(r.IPList)) - assert.True(t, r.IPList[0].Equal(net.ParseIP("1.2.3.4"))) - assert.True(t, r.IPList[1].Equal(net.ParseIP("1.2.3.5"))) + assert.Len(t, r.IPList, 2) + assert.True(t, r.IPList[0].Equal(net.IP{1, 2, 3, 4})) + assert.True(t, r.IPList[1].Equal(net.IP{1, 2, 3, 5})) r = d.processRewrites("www.host.com", dns.TypeAAAA) assert.Equal(t, Rewritten, r.Reason) assert.Equal(t, "host.com", r.CanonName) - assert.Equal(t, 1, len(r.IPList)) + assert.Len(t, r.IPList, 1) assert.True(t, r.IPList[0].Equal(net.ParseIP("1:2:3::4"))) // wildcard @@ -45,11 +45,11 @@ func TestRewrites(t *testing.T) { d.prepareRewrites() r = d.processRewrites("host.com", dns.TypeA) assert.Equal(t, Rewritten, r.Reason) - assert.True(t, r.IPList[0].Equal(net.ParseIP("1.2.3.4"))) + assert.True(t, r.IPList[0].Equal(net.IP{1, 2, 3, 4})) r = d.processRewrites("www.host.com", dns.TypeA) assert.Equal(t, Rewritten, r.Reason) - assert.True(t, r.IPList[0].Equal(net.ParseIP("1.2.3.5"))) + assert.True(t, r.IPList[0].Equal(net.IP{1, 2, 3, 5})) r = d.processRewrites("www.host2.com", dns.TypeA) assert.Equal(t, NotFilteredNotFound, r.Reason) @@ -62,8 +62,8 @@ func TestRewrites(t *testing.T) { d.prepareRewrites() r = d.processRewrites("a.host.com", dns.TypeA) assert.Equal(t, Rewritten, r.Reason) - assert.True(t, len(r.IPList) == 1) - assert.True(t, r.IPList[0].Equal(net.ParseIP("1.2.3.4"))) + assert.Len(t, r.IPList, 1) + assert.True(t, r.IPList[0].Equal(net.IP{1, 2, 3, 4})) // wildcard + CNAME d.Rewrites = []RewriteEntry{ @@ -74,7 +74,7 @@ func TestRewrites(t *testing.T) { r = d.processRewrites("www.host.com", dns.TypeA) assert.Equal(t, Rewritten, r.Reason) assert.Equal(t, "host.com", r.CanonName) - assert.True(t, r.IPList[0].Equal(net.ParseIP("1.2.3.4"))) + assert.True(t, r.IPList[0].Equal(net.IP{1, 2, 3, 4})) // 2 CNAMEs d.Rewrites = []RewriteEntry{ @@ -86,8 +86,8 @@ func TestRewrites(t *testing.T) { r = d.processRewrites("b.host.com", dns.TypeA) assert.Equal(t, Rewritten, r.Reason) assert.Equal(t, "host.com", r.CanonName) - assert.True(t, len(r.IPList) == 1) - assert.True(t, r.IPList[0].Equal(net.ParseIP("1.2.3.4"))) + assert.Len(t, r.IPList, 1) + assert.True(t, r.IPList[0].Equal(net.IP{1, 2, 3, 4})) // 2 CNAMEs + wildcard d.Rewrites = []RewriteEntry{ @@ -99,8 +99,8 @@ func TestRewrites(t *testing.T) { r = d.processRewrites("b.host.com", dns.TypeA) assert.Equal(t, Rewritten, r.Reason) assert.Equal(t, "x.somehost.com", r.CanonName) - assert.True(t, len(r.IPList) == 1) - assert.True(t, r.IPList[0].Equal(net.ParseIP("1.2.3.4"))) + assert.Len(t, r.IPList, 1) + assert.True(t, r.IPList[0].Equal(net.IP{1, 2, 3, 4})) } func TestRewritesLevels(t *testing.T) { @@ -116,19 +116,19 @@ func TestRewritesLevels(t *testing.T) { // match exact r := d.processRewrites("host.com", dns.TypeA) assert.Equal(t, Rewritten, r.Reason) - assert.Equal(t, 1, len(r.IPList)) + assert.Len(t, r.IPList, 1) assert.Equal(t, "1.1.1.1", r.IPList[0].String()) // match L2 r = d.processRewrites("sub.host.com", dns.TypeA) assert.Equal(t, Rewritten, r.Reason) - assert.Equal(t, 1, len(r.IPList)) + assert.Len(t, r.IPList, 1) assert.Equal(t, "2.2.2.2", r.IPList[0].String()) // match L3 r = d.processRewrites("my.sub.host.com", dns.TypeA) assert.Equal(t, Rewritten, r.Reason) - assert.Equal(t, 1, len(r.IPList)) + assert.Len(t, r.IPList, 1) assert.Equal(t, "3.3.3.3", r.IPList[0].String()) } @@ -144,7 +144,7 @@ func TestRewritesExceptionCNAME(t *testing.T) { // match sub-domain r := d.processRewrites("my.host.com", dns.TypeA) assert.Equal(t, Rewritten, r.Reason) - assert.Equal(t, 1, len(r.IPList)) + assert.Len(t, r.IPList, 1) assert.Equal(t, "2.2.2.2", r.IPList[0].String()) // match sub-domain, but handle exception @@ -164,7 +164,7 @@ func TestRewritesExceptionWC(t *testing.T) { // match sub-domain r := d.processRewrites("my.host.com", dns.TypeA) assert.Equal(t, Rewritten, r.Reason) - assert.Equal(t, 1, len(r.IPList)) + assert.Len(t, r.IPList, 1) assert.Equal(t, "2.2.2.2", r.IPList[0].String()) // match sub-domain, but handle exception @@ -187,7 +187,7 @@ func TestRewritesExceptionIP(t *testing.T) { // match domain r := d.processRewrites("host.com", dns.TypeA) assert.Equal(t, Rewritten, r.Reason) - assert.Equal(t, 1, len(r.IPList)) + assert.Len(t, r.IPList, 1) assert.Equal(t, "1.2.3.4", r.IPList[0].String()) // match exception @@ -201,7 +201,7 @@ func TestRewritesExceptionIP(t *testing.T) { // match domain r = d.processRewrites("host2.com", dns.TypeAAAA) assert.Equal(t, Rewritten, r.Reason) - assert.Equal(t, 1, len(r.IPList)) + assert.Len(t, r.IPList, 1) assert.Equal(t, "::1", r.IPList[0].String()) // match exception @@ -211,5 +211,5 @@ func TestRewritesExceptionIP(t *testing.T) { // match domain r = d.processRewrites("host3.com", dns.TypeAAAA) assert.Equal(t, Rewritten, r.Reason) - assert.Equal(t, 0, len(r.IPList)) + assert.Empty(t, r.IPList) } diff --git a/internal/dnsfilter/safebrowsing.go b/internal/dnsfilter/safebrowsing.go index f5aaca9f379..8f4c8f306ac 100644 --- a/internal/dnsfilter/safebrowsing.go +++ b/internal/dnsfilter/safebrowsing.go @@ -37,8 +37,8 @@ func (d *DNSFilter) initSecurityServices() error { opts := upstream.Options{ Timeout: dnsTimeout, ServerIPAddrs: []net.IP{ - net.ParseIP("94.140.14.15"), - net.ParseIP("94.140.15.16"), + {94, 140, 14, 15}, + {94, 140, 15, 16}, net.ParseIP("2a10:50c0::bad1:ff"), net.ParseIP("2a10:50c0::bad2:ff"), }, diff --git a/internal/dnsfilter/safebrowsing_test.go b/internal/dnsfilter/safebrowsing_test.go index 71e59446cea..060664b5597 100644 --- a/internal/dnsfilter/safebrowsing_test.go +++ b/internal/dnsfilter/safebrowsing_test.go @@ -14,7 +14,7 @@ import ( func TestSafeBrowsingHash(t *testing.T) { // test hostnameToHashes() hashes := hostnameToHashes("1.2.3.sub.host.com") - assert.Equal(t, 3, len(hashes)) + assert.Len(t, hashes, 3) _, ok := hashes[sha256.Sum256([]byte("3.sub.host.com"))] assert.True(t, ok) _, ok = hashes[sha256.Sum256([]byte("sub.host.com"))] @@ -31,9 +31,9 @@ func TestSafeBrowsingHash(t *testing.T) { q := c.getQuestion() - assert.True(t, strings.Contains(q, "7a1b.")) - assert.True(t, strings.Contains(q, "af5a.")) - assert.True(t, strings.Contains(q, "eb11.")) + assert.Contains(t, q, "7a1b.") + assert.Contains(t, q, "af5a.") + assert.Contains(t, q, "eb11.") assert.True(t, strings.HasSuffix(q, "sb.dns.adguard.com.")) } @@ -81,7 +81,7 @@ func TestSafeBrowsingCache(t *testing.T) { c.hashToHost[hash] = "sub.host.com" hash = sha256.Sum256([]byte("nonexisting.com")) c.hashToHost[hash] = "nonexisting.com" - assert.Equal(t, 0, c.getCached()) + assert.Empty(t, c.getCached()) hash = sha256.Sum256([]byte("sub.host.com")) _, ok := c.hashToHost[hash] @@ -103,7 +103,7 @@ func TestSafeBrowsingCache(t *testing.T) { c.hashToHost[hash] = "sub.host.com" c.cache.Set(hash[0:2], make([]byte, 32)) - assert.Equal(t, 0, c.getCached()) + assert.Empty(t, c.getCached()) } // testErrUpstream implements upstream.Upstream interface for replacing real diff --git a/internal/dnsforward/access_test.go b/internal/dnsforward/access_test.go index 250b49316c5..5c225b21e98 100644 --- a/internal/dnsforward/access_test.go +++ b/internal/dnsforward/access_test.go @@ -8,28 +8,28 @@ import ( func TestIsBlockedIPAllowed(t *testing.T) { a := &accessCtx{} - assert.True(t, a.Init([]string{"1.1.1.1", "2.2.0.0/16"}, nil, nil) == nil) + assert.Nil(t, a.Init([]string{"1.1.1.1", "2.2.0.0/16"}, nil, nil)) disallowed, disallowedRule := a.IsBlockedIP("1.1.1.1") assert.False(t, disallowed) - assert.Equal(t, "", disallowedRule) + assert.Empty(t, disallowedRule) disallowed, disallowedRule = a.IsBlockedIP("1.1.1.2") assert.True(t, disallowed) - assert.Equal(t, "", disallowedRule) + assert.Empty(t, disallowedRule) disallowed, disallowedRule = a.IsBlockedIP("2.2.1.1") assert.False(t, disallowed) - assert.Equal(t, "", disallowedRule) + assert.Empty(t, disallowedRule) disallowed, disallowedRule = a.IsBlockedIP("2.3.1.1") assert.True(t, disallowed) - assert.Equal(t, "", disallowedRule) + assert.Empty(t, disallowedRule) } func TestIsBlockedIPDisallowed(t *testing.T) { a := &accessCtx{} - assert.True(t, a.Init(nil, []string{"1.1.1.1", "2.2.0.0/16"}, nil) == nil) + assert.Nil(t, a.Init(nil, []string{"1.1.1.1", "2.2.0.0/16"}, nil)) disallowed, disallowedRule := a.IsBlockedIP("1.1.1.1") assert.True(t, disallowed) @@ -37,7 +37,7 @@ func TestIsBlockedIPDisallowed(t *testing.T) { disallowed, disallowedRule = a.IsBlockedIP("1.1.1.2") assert.False(t, disallowed) - assert.Equal(t, "", disallowedRule) + assert.Empty(t, disallowedRule) disallowed, disallowedRule = a.IsBlockedIP("2.2.1.1") assert.True(t, disallowed) @@ -45,7 +45,7 @@ func TestIsBlockedIPDisallowed(t *testing.T) { disallowed, disallowedRule = a.IsBlockedIP("2.3.1.1") assert.False(t, disallowed) - assert.Equal(t, "", disallowedRule) + assert.Empty(t, disallowedRule) } func TestIsBlockedIPBlockedDomain(t *testing.T) { @@ -60,13 +60,13 @@ func TestIsBlockedIPBlockedDomain(t *testing.T) { // match by "host2.com" assert.True(t, a.IsBlockedDomain("host1")) assert.True(t, a.IsBlockedDomain("host2")) - assert.True(t, !a.IsBlockedDomain("host3")) + assert.False(t, a.IsBlockedDomain("host3")) // match by wildcard "*.host.com" - assert.True(t, !a.IsBlockedDomain("host.com")) + assert.False(t, a.IsBlockedDomain("host.com")) assert.True(t, a.IsBlockedDomain("asdf.host.com")) assert.True(t, a.IsBlockedDomain("qwer.asdf.host.com")) - assert.True(t, !a.IsBlockedDomain("asdf.zhost.com")) + assert.False(t, a.IsBlockedDomain("asdf.zhost.com")) // match by wildcard "||host3.com^" assert.True(t, a.IsBlockedDomain("host3.com")) diff --git a/internal/dnsforward/config.go b/internal/dnsforward/config.go index 881174d1e20..0ff078f7875 100644 --- a/internal/dnsforward/config.go +++ b/internal/dnsforward/config.go @@ -29,17 +29,16 @@ type FilteringConfig struct { // GetCustomUpstreamByClient - a callback function that returns upstreams configuration // based on the client IP address. Returns nil if there are no custom upstreams for the client + // TODO(e.burkov): replace argument type with net.IP. GetCustomUpstreamByClient func(clientAddr string) *proxy.UpstreamConfig `yaml:"-"` // Protection configuration // -- - ProtectionEnabled bool `yaml:"protection_enabled"` // whether or not use any of dnsfilter features - BlockingMode string `yaml:"blocking_mode"` // mode how to answer filtered requests - BlockingIPv4 string `yaml:"blocking_ipv4"` // IP address to be returned for a blocked A request - BlockingIPv6 string `yaml:"blocking_ipv6"` // IP address to be returned for a blocked AAAA request - BlockingIPAddrv4 net.IP `yaml:"-"` - BlockingIPAddrv6 net.IP `yaml:"-"` + ProtectionEnabled bool `yaml:"protection_enabled"` // whether or not use any of dnsfilter features + BlockingMode string `yaml:"blocking_mode"` // mode how to answer filtered requests + BlockingIPv4 net.IP `yaml:"blocking_ipv4"` // IP address to be returned for a blocked A request + BlockingIPv6 net.IP `yaml:"blocking_ipv6"` // IP address to be returned for a blocked AAAA request BlockedResponseTTL uint32 `yaml:"blocked_response_ttl"` // if 0, then default is used (3600) // IP (or domain name) which is used to respond to DNS requests blocked by parental control or safe-browsing diff --git a/internal/dnsforward/dns.go b/internal/dnsforward/dns.go index 4a47cdc15c0..10a965e25ea 100644 --- a/internal/dnsforward/dns.go +++ b/internal/dnsforward/dns.go @@ -182,7 +182,7 @@ func processInternalHosts(ctx *dnsContext) int { return resultDone } - log.Debug("DNS: internal record: %s -> %s", req.Question[0].Name, ip.String()) + log.Debug("DNS: internal record: %s -> %s", req.Question[0].Name, ip) resp := s.makeResponse(req) @@ -278,7 +278,7 @@ func processFilteringBeforeRequest(ctx *dnsContext) int { return resultDone } -// Pass request to upstream servers; process the response +// processUpstream passes request to upstream servers and handles the response. func processUpstream(ctx *dnsContext) int { s := ctx.srv d := ctx.proxyCtx @@ -287,7 +287,7 @@ func processUpstream(ctx *dnsContext) int { } if d.Addr != nil && s.conf.GetCustomUpstreamByClient != nil { - clientIP := ipFromAddr(d.Addr) + clientIP := IPStringFromAddr(d.Addr) upstreamsConf := s.conf.GetCustomUpstreamByClient(clientIP) if upstreamsConf != nil { log.Debug("Using custom upstreams for %s", clientIP) diff --git a/internal/dnsforward/dnsforward.go b/internal/dnsforward/dnsforward.go index f1b7e7d2528..ab6bea275e7 100644 --- a/internal/dnsforward/dnsforward.go +++ b/internal/dnsforward/dnsforward.go @@ -178,9 +178,7 @@ func (s *Server) Prepare(config *ServerConfig) error { if config != nil { s.conf = *config if s.conf.BlockingMode == "custom_ip" { - s.conf.BlockingIPAddrv4 = net.ParseIP(s.conf.BlockingIPv4) - s.conf.BlockingIPAddrv6 = net.ParseIP(s.conf.BlockingIPv6) - if s.conf.BlockingIPAddrv4 == nil || s.conf.BlockingIPAddrv6 == nil { + if s.conf.BlockingIPv4 == nil || s.conf.BlockingIPv6 == nil { return fmt.Errorf("dns: invalid custom blocking IP address specified") } } diff --git a/internal/dnsforward/dnsforward_test.go b/internal/dnsforward/dnsforward_test.go index fdee8648770..ab9bb03cfdd 100644 --- a/internal/dnsforward/dnsforward_test.go +++ b/internal/dnsforward/dnsforward_test.go @@ -286,7 +286,7 @@ func TestBlockedRequest(t *testing.T) { t.Fatalf("Couldn't talk to server %s: %s", addr, err) } assert.Equal(t, dns.RcodeSuccess, reply.Rcode) - assert.True(t, reply.Answer[0].(*dns.A).A.Equal(net.ParseIP("0.0.0.0"))) + assert.True(t, reply.Answer[0].(*dns.A).A.Equal(net.IP{0, 0, 0, 0})) err = s.Stop() if err != nil { @@ -300,7 +300,7 @@ func TestServerCustomClientUpstream(t *testing.T) { uc := &proxy.UpstreamConfig{} u := &testUpstream{} u.ipv4 = map[string][]net.IP{} - u.ipv4["host."] = []net.IP{net.ParseIP("192.168.0.1")} + u.ipv4["host."] = []net.IP{{192, 168, 0, 1}} uc.Upstreams = append(uc.Upstreams, u) return uc } @@ -425,7 +425,7 @@ func TestBlockCNAMEProtectionEnabled(t *testing.T) { testUpstm := &testUpstream{testCNAMEs, testIPv4, nil} s.conf.ProtectionEnabled = false err := s.startWithUpstream(testUpstm) - assert.True(t, err == nil) + assert.Nil(t, err) addr := s.dnsProxy.Addr(proxy.ProtoUDP) // 'badhost' has a canonical name 'null.example.org' which is blocked by filters: @@ -440,16 +440,16 @@ func TestBlockCNAME(t *testing.T) { s := createTestServer(t) testUpstm := &testUpstream{testCNAMEs, testIPv4, nil} err := s.startWithUpstream(testUpstm) - assert.True(t, err == nil) + assert.Nil(t, err) addr := s.dnsProxy.Addr(proxy.ProtoUDP) // 'badhost' has a canonical name 'null.example.org' which is blocked by filters: // response is blocked req := createTestMessage("badhost.") reply, err := dns.Exchange(req, addr.String()) - assert.Nil(t, err, nil) + assert.Nil(t, err) assert.Equal(t, dns.RcodeSuccess, reply.Rcode) - assert.True(t, reply.Answer[0].(*dns.A).A.Equal(net.ParseIP("0.0.0.0"))) + assert.True(t, reply.Answer[0].(*dns.A).A.Equal(net.IP{0, 0, 0, 0})) // 'whitelist.example.org' has a canonical name 'null.example.org' which is blocked by filters // but 'whitelist.example.org' is in a whitelist: @@ -465,7 +465,7 @@ func TestBlockCNAME(t *testing.T) { reply, err = dns.Exchange(req, addr.String()) assert.Nil(t, err) assert.Equal(t, dns.RcodeSuccess, reply.Rcode) - assert.True(t, reply.Answer[0].(*dns.A).A.Equal(net.ParseIP("0.0.0.0"))) + assert.True(t, reply.Answer[0].(*dns.A).A.Equal(net.IP{0, 0, 0, 0})) _ = s.Stop() } @@ -548,13 +548,13 @@ func TestBlockedCustomIP(t *testing.T) { conf.TCPListenAddr = &net.TCPAddr{Port: 0} conf.ProtectionEnabled = true conf.BlockingMode = "custom_ip" - conf.BlockingIPv4 = "bad IP" + conf.BlockingIPv4 = nil conf.UpstreamDNS = []string{"8.8.8.8:53", "8.8.4.4:53"} err := s.Prepare(&conf) - assert.True(t, err != nil) // invalid BlockingIPv4 + assert.NotNil(t, err) // invalid BlockingIPv4 - conf.BlockingIPv4 = "0.0.0.1" - conf.BlockingIPv6 = "::1" + conf.BlockingIPv4 = net.IP{0, 0, 0, 1} + conf.BlockingIPv6 = net.ParseIP("::1") err = s.Prepare(&conf) assert.Nil(t, err) err = s.Start() @@ -565,7 +565,7 @@ func TestBlockedCustomIP(t *testing.T) { req := createTestMessageWithType("null.example.org.", dns.TypeA) reply, err := dns.Exchange(req, addr.String()) assert.Nil(t, err) - assert.Equal(t, 1, len(reply.Answer)) + assert.Len(t, reply.Answer, 1) a, ok := reply.Answer[0].(*dns.A) assert.True(t, ok) assert.Equal(t, "0.0.0.1", a.A.String()) @@ -573,7 +573,7 @@ func TestBlockedCustomIP(t *testing.T) { req = createTestMessageWithType("null.example.org.", dns.TypeAAAA) reply, err = dns.Exchange(req, addr.String()) assert.Nil(t, err) - assert.Equal(t, 1, len(reply.Answer)) + assert.Len(t, reply.Answer, 1) a6, ok := reply.Answer[0].(*dns.AAAA) assert.True(t, ok) assert.Equal(t, "::1", a6.AAAA.String()) @@ -710,7 +710,7 @@ func TestRewrite(t *testing.T) { req := createTestMessageWithType("test.com.", dns.TypeA) reply, err := dns.Exchange(req, addr.String()) assert.Nil(t, err) - assert.Equal(t, 1, len(reply.Answer)) + assert.Len(t, reply.Answer, 1) a, ok := reply.Answer[0].(*dns.A) assert.True(t, ok) assert.Equal(t, "1.2.3.4", a.A.String()) @@ -718,12 +718,12 @@ func TestRewrite(t *testing.T) { req = createTestMessageWithType("test.com.", dns.TypeAAAA) reply, err = dns.Exchange(req, addr.String()) assert.Nil(t, err) - assert.Equal(t, 0, len(reply.Answer)) + assert.Empty(t, reply.Answer) req = createTestMessageWithType("alias.test.com.", dns.TypeA) reply, err = dns.Exchange(req, addr.String()) assert.Nil(t, err) - assert.Equal(t, 2, len(reply.Answer)) + assert.Len(t, reply.Answer, 2) assert.Equal(t, "test.com.", reply.Answer[0].(*dns.CNAME).Target) assert.Equal(t, "1.2.3.4", reply.Answer[1].(*dns.A).A.String()) @@ -731,7 +731,7 @@ func TestRewrite(t *testing.T) { reply, err = dns.Exchange(req, addr.String()) assert.Nil(t, err) assert.Equal(t, "my.alias.example.org.", reply.Question[0].Name) // the original question is restored - assert.Equal(t, 2, len(reply.Answer)) + assert.Len(t, reply.Answer, 2) assert.Equal(t, "example.org.", reply.Answer[0].(*dns.CNAME).Target) assert.Equal(t, dns.TypeA, reply.Answer[1].Header().Rrtype) @@ -765,7 +765,7 @@ func createTestServer(t *testing.T) *Server { s.conf.ConfigModified = func() {} err := s.Prepare(nil) - assert.True(t, err == nil) + assert.Nil(t, err) return s } @@ -1011,16 +1011,14 @@ func TestValidateUpstreamsSet(t *testing.T) { assert.NotNil(t, err, "there is an invalid upstream in set, but it pass through validation") } -func TestIpFromAddr(t *testing.T) { +func TestIPStringFromAddr(t *testing.T) { addr := net.UDPAddr{} addr.IP = net.ParseIP("1:2:3::4") addr.Port = 12345 addr.Zone = "eth0" - a := ipFromAddr(&addr) - assert.True(t, a == "1:2:3::4") + assert.Equal(t, IPStringFromAddr(&addr), net.ParseIP("1:2:3::4").String()) - a = ipFromAddr(nil) - assert.True(t, a == "") + assert.Empty(t, IPStringFromAddr(nil)) } func TestMatchDNSName(t *testing.T) { @@ -1030,9 +1028,9 @@ func TestMatchDNSName(t *testing.T) { assert.True(t, matchDNSName(dnsNames, "a.host2")) assert.True(t, matchDNSName(dnsNames, "b.a.host2")) assert.True(t, matchDNSName(dnsNames, "1.2.3.4")) - assert.True(t, !matchDNSName(dnsNames, "host2")) - assert.True(t, !matchDNSName(dnsNames, "")) - assert.True(t, !matchDNSName(dnsNames, "*.host2")) + assert.False(t, matchDNSName(dnsNames, "host2")) + assert.False(t, matchDNSName(dnsNames, "")) + assert.False(t, matchDNSName(dnsNames, "*.host2")) } type testDHCP struct { @@ -1040,7 +1038,7 @@ type testDHCP struct { func (d *testDHCP) Leases(flags int) []dhcpd.Lease { l := dhcpd.Lease{} - l.IP = net.ParseIP("127.0.0.1").To4() + l.IP = net.IP{127, 0, 0, 1} l.HWAddr, _ = net.ParseMAC("aa:aa:aa:aa:aa:aa") l.Hostname = "localhost" return []dhcpd.Lease{l} @@ -1058,7 +1056,7 @@ func TestPTRResponseFromDHCPLeases(t *testing.T) { s.conf.UpstreamDNS = []string{"127.0.0.1:53"} s.conf.FilteringConfig.ProtectionEnabled = true err := s.Prepare(nil) - assert.True(t, err == nil) + assert.Nil(t, err) assert.Nil(t, s.Start()) addr := s.dnsProxy.Addr(proxy.ProtoUDP) @@ -1067,7 +1065,7 @@ func TestPTRResponseFromDHCPLeases(t *testing.T) { resp, err := dns.Exchange(req, addr.String()) assert.Nil(t, err) - assert.Equal(t, 1, len(resp.Answer)) + assert.Len(t, resp.Answer, 1) assert.Equal(t, dns.TypePTR, resp.Answer[0].Header().Rrtype) assert.Equal(t, "1.0.0.127.in-addr.arpa.", resp.Answer[0].Header().Name) ptr := resp.Answer[0].(*dns.PTR) @@ -1100,7 +1098,7 @@ func TestPTRResponseFromHosts(t *testing.T) { s.conf.UpstreamDNS = []string{"127.0.0.1:53"} s.conf.FilteringConfig.ProtectionEnabled = true err := s.Prepare(nil) - assert.True(t, err == nil) + assert.Nil(t, err) assert.Nil(t, s.Start()) addr := s.dnsProxy.Addr(proxy.ProtoUDP) @@ -1109,7 +1107,7 @@ func TestPTRResponseFromHosts(t *testing.T) { resp, err := dns.Exchange(req, addr.String()) assert.Nil(t, err) - assert.Equal(t, 1, len(resp.Answer)) + assert.Len(t, resp.Answer, 1) assert.Equal(t, dns.TypePTR, resp.Answer[0].Header().Rrtype) assert.Equal(t, "1.0.0.127.in-addr.arpa.", resp.Answer[0].Header().Name) ptr := resp.Answer[0].(*dns.PTR) diff --git a/internal/dnsforward/filter.go b/internal/dnsforward/filter.go index 80cf26dd551..77ae30a94d5 100644 --- a/internal/dnsforward/filter.go +++ b/internal/dnsforward/filter.go @@ -12,7 +12,7 @@ import ( ) func (s *Server) beforeRequestHandler(_ *proxy.Proxy, d *proxy.DNSContext) (bool, error) { - ip := ipFromAddr(d.Addr) + ip := IPStringFromAddr(d.Addr) disallowed, _ := s.access.IsBlockedIP(ip) if disallowed { log.Tracef("Client IP %s is blocked by settings", ip) @@ -36,7 +36,7 @@ func (s *Server) getClientRequestFilteringSettings(d *proxy.DNSContext) *dnsfilt setts := s.dnsFilter.GetConfig() setts.FilteringEnabled = true if s.conf.FilterHandler != nil { - clientAddr := ipFromAddr(d.Addr) + clientAddr := IPStringFromAddr(d.Addr) s.conf.FilterHandler(clientAddr, &setts) } return &setts diff --git a/internal/dnsforward/http.go b/internal/dnsforward/http.go index e24ba89eb20..1f5780a61cf 100644 --- a/internal/dnsforward/http.go +++ b/internal/dnsforward/http.go @@ -28,8 +28,8 @@ type dnsConfig struct { ProtectionEnabled *bool `json:"protection_enabled"` RateLimit *uint32 `json:"ratelimit"` BlockingMode *string `json:"blocking_mode"` - BlockingIPv4 *string `json:"blocking_ipv4"` - BlockingIPv6 *string `json:"blocking_ipv6"` + BlockingIPv4 net.IP `json:"blocking_ipv4"` + BlockingIPv6 net.IP `json:"blocking_ipv6"` EDNSCSEnabled *bool `json:"edns_cs_enabled"` DNSSECEnabled *bool `json:"dnssec_enabled"` DisableIPv6 *bool `json:"disable_ipv6"` @@ -68,8 +68,8 @@ func (s *Server) getDNSConfig() dnsConfig { Bootstraps: &bootstraps, ProtectionEnabled: &protectionEnabled, BlockingMode: &blockingMode, - BlockingIPv4: &BlockingIPv4, - BlockingIPv6: &BlockingIPv6, + BlockingIPv4: BlockingIPv4, + BlockingIPv6: BlockingIPv6, RateLimit: &Ratelimit, EDNSCSEnabled: &EnableEDNSClientSubnet, DNSSECEnabled: &EnableDNSSEC, @@ -100,17 +100,11 @@ func (req *dnsConfig) checkBlockingMode() bool { bm := *req.BlockingMode if bm == "custom_ip" { - if req.BlockingIPv4 == nil || req.BlockingIPv6 == nil { + if req.BlockingIPv4.To4() == nil { return false } - ip4 := net.ParseIP(*req.BlockingIPv4) - if ip4 == nil || ip4.To4() == nil { - return false - } - - ip6 := net.ParseIP(*req.BlockingIPv6) - return ip6 != nil + return req.BlockingIPv6 != nil } for _, valid := range []string{ @@ -247,10 +241,8 @@ func (s *Server) setConfig(dc dnsConfig) (restart bool) { if dc.BlockingMode != nil { s.conf.BlockingMode = *dc.BlockingMode if *dc.BlockingMode == "custom_ip" { - s.conf.BlockingIPv4 = *dc.BlockingIPv4 - s.conf.BlockingIPAddrv4 = net.ParseIP(*dc.BlockingIPv4) - s.conf.BlockingIPv6 = *dc.BlockingIPv6 - s.conf.BlockingIPAddrv6 = net.ParseIP(*dc.BlockingIPv6) + s.conf.BlockingIPv4 = dc.BlockingIPv4.To4() + s.conf.BlockingIPv6 = dc.BlockingIPv6.To16() } } diff --git a/internal/dnsforward/msg.go b/internal/dnsforward/msg.go index 71497a6cbf1..ba95bbcef35 100644 --- a/internal/dnsforward/msg.go +++ b/internal/dnsforward/msg.go @@ -60,9 +60,9 @@ func (s *Server) genDNSFilterMessage(d *proxy.DNSContext, result *dnsfilter.Resu switch m.Question[0].Qtype { case dns.TypeA: - return s.genARecord(m, s.conf.BlockingIPAddrv4) + return s.genARecord(m, s.conf.BlockingIPv4) case dns.TypeAAAA: - return s.genAAAARecord(m, s.conf.BlockingIPAddrv6) + return s.genAAAARecord(m, s.conf.BlockingIPv6) } } else if s.conf.BlockingMode == "nxdomain" { // means that we should return NXDOMAIN for any blocked request diff --git a/internal/dnsforward/stats.go b/internal/dnsforward/stats.go index c447be05c0b..822df6a0a7d 100644 --- a/internal/dnsforward/stats.go +++ b/internal/dnsforward/stats.go @@ -36,7 +36,7 @@ func processQueryLogsAndStats(ctx *dnsContext) int { OrigAnswer: ctx.origResp, Result: ctx.result, Elapsed: elapsed, - ClientIP: getIP(d.Addr), + ClientIP: ipFromAddr(d.Addr), } switch d.Proto { diff --git a/internal/dnsforward/util.go b/internal/dnsforward/util.go index da87f810a68..3a8c1cb3027 100644 --- a/internal/dnsforward/util.go +++ b/internal/dnsforward/util.go @@ -8,45 +8,32 @@ import ( "github.com/AdguardTeam/golibs/utils" ) -// GetIPString is a helper function that extracts IP address from net.Addr -func GetIPString(addr net.Addr) string { +// ipFromAddr gets IP address from addr. +func ipFromAddr(addr net.Addr) (ip net.IP) { switch addr := addr.(type) { case *net.UDPAddr: - return addr.IP.String() + return addr.IP case *net.TCPAddr: - return addr.IP.String() + return addr.IP } - return "" -} - -func stringArrayDup(a []string) []string { - a2 := make([]string, len(a)) - copy(a2, a) - return a2 + return nil } -// Get IP address from net.Addr object +// IPStringFromAddr extracts IP address from net.Addr. // Note: we can't use net.SplitHostPort(a.String()) because of IPv6 zone: // https://github.com/AdguardTeam/AdGuardHome/internal/issues/1261 -func ipFromAddr(a net.Addr) string { - switch addr := a.(type) { - case *net.UDPAddr: - return addr.IP.String() - case *net.TCPAddr: - return addr.IP.String() +func IPStringFromAddr(addr net.Addr) (ipstr string) { + if ip := ipFromAddr(addr); ip != nil { + return ip.String() } + return "" } -// Get IP address from net.Addr -func getIP(addr net.Addr) net.IP { - switch addr := addr.(type) { - case *net.UDPAddr: - return addr.IP - case *net.TCPAddr: - return addr.IP - } - return nil +func stringArrayDup(a []string) []string { + a2 := make([]string, len(a)) + copy(a2, a) + return a2 } // Find value in a sorted array diff --git a/internal/home/auth_test.go b/internal/home/auth_test.go index 0998a2a6707..3811b170d1a 100644 --- a/internal/home/auth_test.go +++ b/internal/home/auth_test.go @@ -70,7 +70,7 @@ func TestAuth(t *testing.T) { a.Close() u := a.UserFind("name", "password") - assert.True(t, len(u.Name) != 0) + assert.NotEmpty(t, u.Name) time.Sleep(3 * time.Second) @@ -125,9 +125,9 @@ func TestAuthHTTP(t *testing.T) { r.URL = &url.URL{Path: "/"} handlerCalled = false handler2(&w, &r) - assert.True(t, w.statusCode == http.StatusFound) - assert.True(t, w.hdr.Get("Location") != "") - assert.True(t, !handlerCalled) + assert.Equal(t, http.StatusFound, w.statusCode) + assert.NotEmpty(t, w.hdr.Get("Location")) + assert.False(t, handlerCalled) // go to login page loginURL := w.hdr.Get("Location") @@ -139,7 +139,7 @@ func TestAuthHTTP(t *testing.T) { // perform login cookie, err := Context.auth.httpCookie(loginJSON{Name: "name", Password: "password"}) assert.Nil(t, err) - assert.True(t, cookie != "") + assert.NotEmpty(t, cookie) // get / handler2 = optionalAuth(handler) @@ -168,8 +168,8 @@ func TestAuthHTTP(t *testing.T) { r.URL = &url.URL{Path: loginURL} handlerCalled = false handler2(&w, &r) - assert.True(t, w.hdr.Get("Location") != "") - assert.True(t, !handlerCalled) + assert.NotEmpty(t, w.hdr.Get("Location")) + assert.False(t, handlerCalled) r.Header.Del("Cookie") // get login page with an invalid cookie diff --git a/internal/home/clients_test.go b/internal/home/clients_test.go index 9268c08f71b..69f2badabcd 100644 --- a/internal/home/clients_test.go +++ b/internal/home/clients_test.go @@ -37,15 +37,18 @@ func TestClients(t *testing.T) { assert.Nil(t, err) c, b = clients.Find("1.1.1.1") - assert.True(t, b && c.Name == "client1") + assert.True(t, b) + assert.Equal(t, c.Name, "client1") c, b = clients.Find("1:2:3::4") - assert.True(t, b && c.Name == "client1") + assert.True(t, b) + assert.Equal(t, c.Name, "client1") c, b = clients.Find("2.2.2.2") - assert.True(t, b && c.Name == "client2") + assert.True(t, b) + assert.Equal(t, c.Name, "client2") - assert.True(t, !clients.Exists("1.2.3.4", ClientSourceHostsFile)) + assert.False(t, clients.Exists("1.2.3.4", ClientSourceHostsFile)) assert.True(t, clients.Exists("1.1.1.1", ClientSourceHostsFile)) assert.True(t, clients.Exists("2.2.2.2", ClientSourceHostsFile)) }) @@ -109,7 +112,7 @@ func TestClients(t *testing.T) { err := clients.Update("client1", c) assert.Nil(t, err) - assert.True(t, !clients.Exists("1.1.1.1", ClientSourceHostsFile)) + assert.False(t, clients.Exists("1.1.1.1", ClientSourceHostsFile)) assert.True(t, clients.Exists("1.1.1.2", ClientSourceHostsFile)) c = Client{ @@ -123,8 +126,8 @@ func TestClients(t *testing.T) { c, b := clients.Find("1.1.1.2") assert.True(t, b) - assert.True(t, c.Name == "client1-renamed") - assert.True(t, c.IDs[0] == "1.1.1.2") + assert.Equal(t, "client1-renamed", c.Name) + assert.Equal(t, "1.1.1.2", c.IDs[0]) assert.True(t, c.UseOwnSettings) assert.Nil(t, clients.list["client1"]) }) @@ -172,12 +175,12 @@ func TestClientsWhois(t *testing.T) { whois := [][]string{{"orgname", "orgname-val"}, {"country", "country-val"}} // set whois info on new client clients.SetWhoisInfo("1.1.1.255", whois) - assert.True(t, clients.ipHost["1.1.1.255"].WhoisInfo[0][1] == "orgname-val") + assert.Equal(t, "orgname-val", clients.ipHost["1.1.1.255"].WhoisInfo[0][1]) // set whois info on existing auto-client _, _ = clients.AddHost("1.1.1.1", "host", ClientSourceRDNS) clients.SetWhoisInfo("1.1.1.1", whois) - assert.True(t, clients.ipHost["1.1.1.1"].WhoisInfo[0][1] == "orgname-val") + assert.Equal(t, "orgname-val", clients.ipHost["1.1.1.1"].WhoisInfo[0][1]) // Check that we cannot set whois info on a manually-added client c = Client{ @@ -186,7 +189,7 @@ func TestClientsWhois(t *testing.T) { } _, _ = clients.Add(c) clients.SetWhoisInfo("1.1.1.2", whois) - assert.True(t, clients.ipHost["1.1.1.2"] == nil) + assert.Nil(t, clients.ipHost["1.1.1.2"]) _ = clients.Del("client1") } @@ -272,6 +275,6 @@ func TestClientsCustomUpstream(t *testing.T) { config = clients.FindUpstreams("1.1.1.1") assert.NotNil(t, config) - assert.Equal(t, 1, len(config.Upstreams)) - assert.Equal(t, 1, len(config.DomainReservedUpstreams)) + assert.Len(t, config.Upstreams, 1) + assert.Len(t, config.DomainReservedUpstreams, 1) } diff --git a/internal/home/dns.go b/internal/home/dns.go index 1090d9be255..a988062954f 100644 --- a/internal/home/dns.go +++ b/internal/home/dns.go @@ -98,7 +98,7 @@ func isRunning() bool { } func onDNSRequest(d *proxy.DNSContext) { - ip := dnsforward.GetIPString(d.Addr) + ip := dnsforward.IPStringFromAddr(d.Addr) if ip == "" { // This would be quite weird if we get here return diff --git a/internal/home/filter_test.go b/internal/home/filter_test.go index 2bc23be1635..a5b6d20babf 100644 --- a/internal/home/filter_test.go +++ b/internal/home/filter_test.go @@ -50,16 +50,17 @@ func TestFilters(t *testing.T) { // download ok, err := Context.filters.update(&f) - assert.Equal(t, nil, err) + assert.Nil(t, err) assert.True(t, ok) assert.Equal(t, 3, f.RulesCount) // refresh ok, err = Context.filters.update(&f) - assert.True(t, !ok && err == nil) + assert.False(t, ok) + assert.Nil(t, err) err = Context.filters.load(&f) - assert.True(t, err == nil) + assert.Nil(t, err) f.unload() _ = os.Remove(f.Path()) diff --git a/internal/home/home_test.go b/internal/home/home_test.go index 1b16e35745b..b21d7d46200 100644 --- a/internal/home/home_test.go +++ b/internal/home/home_test.go @@ -119,7 +119,7 @@ func TestHome(t *testing.T) { fn := filepath.Join(dir, "AdGuardHome.yaml") // Prepare the test config - assert.True(t, ioutil.WriteFile(fn, []byte(yamlConf), 0o644) == nil) + assert.Nil(t, ioutil.WriteFile(fn, []byte(yamlConf), 0o644)) fn, _ = filepath.Abs(fn) config = configuration{} // the global variable is dirty because of the previous tests run @@ -138,11 +138,11 @@ func TestHome(t *testing.T) { } time.Sleep(100 * time.Millisecond) } - assert.Truef(t, err == nil, "%s", err) + assert.Nilf(t, err, "%s", err) assert.Equal(t, http.StatusOK, resp.StatusCode) resp, err = h.Get("http://127.0.0.1:3000/control/status") - assert.Truef(t, err == nil, "%s", err) + assert.Nilf(t, err, "%s", err) assert.Equal(t, http.StatusOK, resp.StatusCode) // test DNS over UDP @@ -159,16 +159,16 @@ func TestHome(t *testing.T) { req.RecursionDesired = true req.Question = []dns.Question{{Name: "static.adguard.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}} buf, err := req.Pack() - assert.True(t, err == nil, "%s", err) + assert.Nil(t, err) requestURL := "http://127.0.0.1:3000/dns-query?dns=" + base64.RawURLEncoding.EncodeToString(buf) resp, err = http.DefaultClient.Get(requestURL) - assert.True(t, err == nil, "%s", err) + assert.Nil(t, err) body, err := ioutil.ReadAll(resp.Body) - assert.True(t, err == nil, "%s", err) - assert.True(t, resp.StatusCode == http.StatusOK) + assert.Nil(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) response := dns.Msg{} err = response.Unpack(body) - assert.True(t, err == nil, "%s", err) + assert.Nil(t, err) addrs = nil proxyutil.AppendIPAddrs(&addrs, response.Answer) haveIP = len(addrs) != 0 diff --git a/internal/home/mobileconfig_test.go b/internal/home/mobileconfig_test.go index f5bf3f2b28e..1025fe939c6 100644 --- a/internal/home/mobileconfig_test.go +++ b/internal/home/mobileconfig_test.go @@ -23,7 +23,7 @@ func TestHandleMobileConfigDOH(t *testing.T) { _, err = plist.Unmarshal(w.Body.Bytes(), &mc) assert.Nil(t, err) - if assert.Equal(t, 1, len(mc.PayloadContent)) { + if assert.Len(t, mc.PayloadContent, 1) { assert.Equal(t, "example.org DoH", mc.PayloadContent[0].Name) assert.Equal(t, "example.org DoH", mc.PayloadContent[0].PayloadDisplayName) assert.Equal(t, "example.org", mc.PayloadContent[0].DNSSettings.ServerName) @@ -51,7 +51,7 @@ func TestHandleMobileConfigDOH(t *testing.T) { _, err = plist.Unmarshal(w.Body.Bytes(), &mc) assert.Nil(t, err) - if assert.Equal(t, 1, len(mc.PayloadContent)) { + if assert.Len(t, mc.PayloadContent, 1) { assert.Equal(t, "example.org DoH", mc.PayloadContent[0].Name) assert.Equal(t, "example.org DoH", mc.PayloadContent[0].PayloadDisplayName) assert.Equal(t, "example.org", mc.PayloadContent[0].DNSSettings.ServerName) @@ -89,7 +89,7 @@ func TestHandleMobileConfigDOT(t *testing.T) { _, err = plist.Unmarshal(w.Body.Bytes(), &mc) assert.Nil(t, err) - if assert.Equal(t, 1, len(mc.PayloadContent)) { + if assert.Len(t, mc.PayloadContent, 1) { assert.Equal(t, "example.org DoT", mc.PayloadContent[0].Name) assert.Equal(t, "example.org DoT", mc.PayloadContent[0].PayloadDisplayName) assert.Equal(t, "example.org", mc.PayloadContent[0].DNSSettings.ServerName) @@ -116,7 +116,7 @@ func TestHandleMobileConfigDOT(t *testing.T) { _, err = plist.Unmarshal(w.Body.Bytes(), &mc) assert.Nil(t, err) - if assert.Equal(t, 1, len(mc.PayloadContent)) { + if assert.Len(t, mc.PayloadContent, 1) { assert.Equal(t, "example.org DoT", mc.PayloadContent[0].Name) assert.Equal(t, "example.org DoT", mc.PayloadContent[0].PayloadDisplayName) assert.Equal(t, "example.org", mc.PayloadContent[0].DNSSettings.ServerName) diff --git a/internal/home/rdns_test.go b/internal/home/rdns_test.go index 399886c4871..516b0ed5d5c 100644 --- a/internal/home/rdns_test.go +++ b/internal/home/rdns_test.go @@ -12,10 +12,10 @@ func TestResolveRDNS(t *testing.T) { conf := &dnsforward.ServerConfig{} conf.UpstreamDNS = []string{"8.8.8.8"} err := dns.Prepare(conf) - assert.True(t, err == nil, "%s", err) + assert.Nil(t, err) clients := &clientsContainer{} rdns := InitRDNS(dns, clients) r := rdns.resolve("1.1.1.1") - assert.True(t, r == "one.one.one.one", "%s", r) + assert.Equal(t, "one.one.one.one", r, r) } diff --git a/internal/querylog/decode_test.go b/internal/querylog/decode_test.go index ffcf94dcd69..a599084d06f 100644 --- a/internal/querylog/decode_test.go +++ b/internal/querylog/decode_test.go @@ -84,7 +84,7 @@ func TestDecodeLogEntry(t *testing.T) { decodeLogEntry(got, data) s := logOutput.String() - assert.Equal(t, "", s) + assert.Empty(t, s) // Correct for time zones. got.Time = got.Time.UTC() @@ -172,7 +172,7 @@ func TestDecodeLogEntry(t *testing.T) { s := logOutput.String() if tc.want == "" { - assert.Equal(t, "", s) + assert.Empty(t, s) } else { assert.True(t, strings.HasSuffix(s, tc.want), "got %q", s) diff --git a/internal/querylog/qlog_test.go b/internal/querylog/qlog_test.go index dfd4e6cebf5..0fa072c118a 100644 --- a/internal/querylog/qlog_test.go +++ b/internal/querylog/qlog_test.go @@ -56,7 +56,7 @@ func TestQueryLog(t *testing.T) { // get all entries params := newSearchParams() entries, _ := l.search(params) - assert.Equal(t, 4, len(entries)) + assert.Len(t, entries, 4) assertLogEntry(t, entries[0], "example.com", "1.1.1.4", "2.2.2.4") assertLogEntry(t, entries[1], "test.example.org", "1.1.1.3", "2.2.2.3") assertLogEntry(t, entries[2], "example.org", "1.1.1.2", "2.2.2.2") @@ -70,7 +70,7 @@ func TestQueryLog(t *testing.T) { value: "TEST.example.org", }) entries, _ = l.search(params) - assert.Equal(t, 1, len(entries)) + assert.Len(t, entries, 1) assertLogEntry(t, entries[0], "test.example.org", "1.1.1.3", "2.2.2.3") // search by domain (not strict) @@ -81,7 +81,7 @@ func TestQueryLog(t *testing.T) { value: "example.ORG", }) entries, _ = l.search(params) - assert.Equal(t, 3, len(entries)) + assert.Len(t, entries, 3) assertLogEntry(t, entries[0], "test.example.org", "1.1.1.3", "2.2.2.3") assertLogEntry(t, entries[1], "example.org", "1.1.1.2", "2.2.2.2") assertLogEntry(t, entries[2], "example.org", "1.1.1.1", "2.2.2.1") @@ -94,7 +94,7 @@ func TestQueryLog(t *testing.T) { value: "2.2.2.2", }) entries, _ = l.search(params) - assert.Equal(t, 1, len(entries)) + assert.Len(t, entries, 1) assertLogEntry(t, entries[0], "example.org", "1.1.1.2", "2.2.2.2") // search by client IP (part of) @@ -105,7 +105,7 @@ func TestQueryLog(t *testing.T) { value: "2.2.2", }) entries, _ = l.search(params) - assert.Equal(t, 4, len(entries)) + assert.Len(t, entries, 4) assertLogEntry(t, entries[0], "example.com", "1.1.1.4", "2.2.2.4") assertLogEntry(t, entries[1], "test.example.org", "1.1.1.3", "2.2.2.3") assertLogEntry(t, entries[2], "example.org", "1.1.1.2", "2.2.2.2") @@ -138,7 +138,7 @@ func TestQueryLogOffsetLimit(t *testing.T) { params.offset = 0 params.limit = 10 entries, _ := l.search(params) - assert.Equal(t, 10, len(entries)) + assert.Len(t, entries, 10) assert.Equal(t, entries[0].QHost, "first.example.org") assert.Equal(t, entries[9].QHost, "first.example.org") @@ -146,7 +146,7 @@ func TestQueryLogOffsetLimit(t *testing.T) { params.offset = 10 params.limit = 10 entries, _ = l.search(params) - assert.Equal(t, 10, len(entries)) + assert.Len(t, entries, 10) assert.Equal(t, entries[0].QHost, "second.example.org") assert.Equal(t, entries[9].QHost, "second.example.org") @@ -154,7 +154,7 @@ func TestQueryLogOffsetLimit(t *testing.T) { params.offset = 15 params.limit = 10 entries, _ = l.search(params) - assert.Equal(t, 5, len(entries)) + assert.Len(t, entries, 5) assert.Equal(t, entries[0].QHost, "second.example.org") assert.Equal(t, entries[4].QHost, "second.example.org") @@ -162,7 +162,7 @@ func TestQueryLogOffsetLimit(t *testing.T) { params.offset = 20 params.limit = 10 entries, _ = l.search(params) - assert.Equal(t, 0, len(entries)) + assert.Empty(t, entries) } func TestQueryLogMaxFileScanEntries(t *testing.T) { @@ -186,11 +186,11 @@ func TestQueryLogMaxFileScanEntries(t *testing.T) { params := newSearchParams() params.maxFileScanEntries = 5 // do not scan more than 5 records entries, _ := l.search(params) - assert.Equal(t, 5, len(entries)) + assert.Len(t, entries, 5) params.maxFileScanEntries = 0 // disable the limit entries, _ = l.search(params) - assert.Equal(t, 10, len(entries)) + assert.Len(t, entries, 10) } func TestQueryLogFileDisabled(t *testing.T) { @@ -211,7 +211,7 @@ func TestQueryLogFileDisabled(t *testing.T) { params := newSearchParams() ll, _ := l.search(params) - assert.Equal(t, 2, len(ll)) + assert.Len(t, ll, 2) assert.Equal(t, "example3.org", ll[0].QHost) assert.Equal(t, "example2.org", ll[1].QHost) } @@ -262,7 +262,7 @@ func assertLogEntry(t *testing.T, entry *logEntry, host, answer, client string) msg := new(dns.Msg) assert.Nil(t, msg.Unpack(entry.Answer)) - assert.Equal(t, 1, len(msg.Answer)) + assert.Len(t, msg.Answer, 1) ip := proxyutil.GetIPFromDNSRecord(msg.Answer[0]) assert.NotNil(t, ip) assert.Equal(t, answer, ip.String()) diff --git a/internal/querylog/qlogfile_test.go b/internal/querylog/qlogfile_test.go index 950eaaf3ae0..7d603ba5121 100644 --- a/internal/querylog/qlogfile_test.go +++ b/internal/querylog/qlogfile_test.go @@ -28,12 +28,12 @@ func TestQLogFileEmpty(t *testing.T) { // seek to the start pos, err := q.SeekStart() assert.Nil(t, err) - assert.Equal(t, int64(0), pos) + assert.EqualValues(t, 0, pos) // try reading anyway line, err := q.ReadNext() assert.Equal(t, io.EOF, err) - assert.Equal(t, "", line) + assert.Empty(t, line) } func TestQLogFileLarge(t *testing.T) { @@ -53,14 +53,14 @@ func TestQLogFileLarge(t *testing.T) { // seek to the start pos, err := q.SeekStart() assert.Nil(t, err) - assert.NotEqual(t, int64(0), pos) + assert.NotEqualValues(t, 0, pos) read := 0 var line string for err == nil { line, err = q.ReadNext() if err == nil { - assert.True(t, len(line) > 0) + assert.NotZero(t, len(line)) read++ } } @@ -109,10 +109,10 @@ func TestQLogFileSeekLargeFile(t *testing.T) { assert.Nil(t, err) // ALMOST the record we need timestamp := readQLogTimestamp(line) - 1 - assert.NotEqual(t, uint64(0), timestamp) + assert.NotEqualValues(t, 0, timestamp) _, depth, err := q.SeekTS(timestamp) assert.NotNil(t, err) - assert.True(t, depth <= int(math.Log2(float64(count))+3)) + assert.LessOrEqual(t, depth, int(math.Log2(float64(count))+3)) } func TestQLogFileSeekSmallFile(t *testing.T) { @@ -155,22 +155,22 @@ func TestQLogFileSeekSmallFile(t *testing.T) { assert.Nil(t, err) // ALMOST the record we need timestamp := readQLogTimestamp(line) - 1 - assert.NotEqual(t, uint64(0), timestamp) + assert.NotEqualValues(t, 0, timestamp) _, depth, err := q.SeekTS(timestamp) assert.NotNil(t, err) - assert.True(t, depth <= int(math.Log2(float64(count))+3)) + assert.LessOrEqual(t, depth, int(math.Log2(float64(count))+3)) } func testSeekLineQLogFile(t *testing.T, q *QLogFile, lineNumber int) { line, err := getQLogFileLine(q, lineNumber) assert.Nil(t, err) ts := readQLogTimestamp(line) - assert.NotEqual(t, uint64(0), ts) + assert.NotEqualValues(t, 0, ts) // try seeking to that line now pos, _, err := q.SeekTS(ts) assert.Nil(t, err) - assert.NotEqual(t, int64(0), pos) + assert.NotEqualValues(t, 0, pos) testLine, err := q.ReadNext() assert.Nil(t, err) @@ -207,27 +207,27 @@ func TestQLogFile(t *testing.T) { // seek to the start pos, err := q.SeekStart() assert.Nil(t, err) - assert.True(t, pos > 0) + assert.Greater(t, pos, int64(0)) // read first line line, err := q.ReadNext() assert.Nil(t, err) - assert.True(t, strings.Contains(line, "0.0.0.2"), line) + assert.Contains(t, line, "0.0.0.2") assert.True(t, strings.HasPrefix(line, "{"), line) assert.True(t, strings.HasSuffix(line, "}"), line) // read second line line, err = q.ReadNext() assert.Nil(t, err) - assert.Equal(t, int64(0), q.position) - assert.True(t, strings.Contains(line, "0.0.0.1"), line) + assert.EqualValues(t, 0, q.position) + assert.Contains(t, line, "0.0.0.1") assert.True(t, strings.HasPrefix(line, "{"), line) assert.True(t, strings.HasSuffix(line, "}"), line) // try reading again (there's nothing to read anymore) line, err = q.ReadNext() assert.Equal(t, io.EOF, err) - assert.Equal(t, "", line) + assert.Empty(t, line) } // prepareTestFile - prepares a test query log file with the specified number of lines diff --git a/internal/querylog/qlogreader_test.go b/internal/querylog/qlogreader_test.go index d9dfb3ea514..967e83965ad 100644 --- a/internal/querylog/qlogreader_test.go +++ b/internal/querylog/qlogreader_test.go @@ -21,7 +21,7 @@ func TestQLogReaderEmpty(t *testing.T) { assert.Nil(t, err) line, err := r.ReadNext() - assert.Equal(t, "", line) + assert.Empty(t, line) assert.Equal(t, io.EOF, err) } @@ -241,7 +241,7 @@ func testSeekLineQLogReader(t *testing.T, r *QLogReader, lineNumber int) { line, err := getQLogReaderLine(r, lineNumber) assert.Nil(t, err) ts := readQLogTimestamp(line) - assert.NotEqual(t, uint64(0), ts) + assert.NotEqualValues(t, 0, ts) // try seeking to that line now err = r.SeekTS(ts) diff --git a/internal/stats/stats_test.go b/internal/stats/stats_test.go index 3a4bed66b7f..47e687997d2 100644 --- a/internal/stats/stats_test.go +++ b/internal/stats/stats_test.go @@ -39,13 +39,13 @@ func TestStats(t *testing.T) { e := Entry{} e.Domain = "domain" - e.Client = net.ParseIP("127.0.0.1") + e.Client = net.IP{127, 0, 0, 1} e.Result = RFiltered e.Time = 123456 s.Update(e) e.Domain = "domain" - e.Client = net.ParseIP("127.0.0.1") + e.Client = net.IP{127, 0, 0, 1} e.Result = RNotFiltered e.Time = 123456 s.Update(e) @@ -64,23 +64,23 @@ func TestStats(t *testing.T) { assert.True(t, UIntArrayEquals(d["replaced_parental"].([]uint64), a)) m := d["top_queried_domains"].([]map[string]uint64) - assert.True(t, m[0]["domain"] == 1) + assert.EqualValues(t, 1, m[0]["domain"]) m = d["top_blocked_domains"].([]map[string]uint64) - assert.True(t, m[0]["domain"] == 1) + assert.EqualValues(t, 1, m[0]["domain"]) m = d["top_clients"].([]map[string]uint64) - assert.True(t, m[0]["127.0.0.1"] == 2) + assert.EqualValues(t, 2, m[0]["127.0.0.1"]) - assert.True(t, d["num_dns_queries"].(uint64) == 2) - assert.True(t, d["num_blocked_filtering"].(uint64) == 1) - assert.True(t, d["num_replaced_safebrowsing"].(uint64) == 0) - assert.True(t, d["num_replaced_safesearch"].(uint64) == 0) - assert.True(t, d["num_replaced_parental"].(uint64) == 0) - assert.True(t, d["avg_processing_time"].(float64) == 0.123456) + assert.EqualValues(t, 2, d["num_dns_queries"].(uint64)) + assert.EqualValues(t, 1, d["num_blocked_filtering"].(uint64)) + assert.EqualValues(t, 0, d["num_replaced_safebrowsing"].(uint64)) + assert.EqualValues(t, 0, d["num_replaced_safesearch"].(uint64)) + assert.EqualValues(t, 0, d["num_replaced_parental"].(uint64)) + assert.EqualValues(t, 0.123456, d["avg_processing_time"].(float64)) topClients := s.GetTopClientsIP(2) - assert.True(t, topClients[0] == "127.0.0.1") + assert.Equal(t, "127.0.0.1", topClients[0]) s.clear() s.Close() @@ -111,7 +111,7 @@ func TestLargeNumbers(t *testing.T) { } for i := 0; i != n; i++ { e.Domain = fmt.Sprintf("domain%d", i) - e.Client = net.ParseIP("127.0.0.1") + e.Client = net.IP{127, 0, 0, 1} e.Client[2] = byte((i & 0xff00) >> 8) e.Client[3] = byte(i & 0xff) e.Result = RNotFiltered @@ -121,7 +121,7 @@ func TestLargeNumbers(t *testing.T) { } d := s.getData() - assert.True(t, d["num_dns_queries"].(uint64) == uint64(int(hour)*n)) + assert.EqualValues(t, int(hour)*n, d["num_dns_queries"]) s.Close() os.Remove(conf.Filename) @@ -152,6 +152,6 @@ func aggregateDataPerDay(firstID uint32) int { func TestAggregateDataPerTimeUnit(t *testing.T) { for i := 0; i != 25; i++ { alen := aggregateDataPerDay(uint32(i)) - assert.True(t, alen == 30, "i=%d", i) + assert.Equalf(t, 30, alen, "i=%d", i) } } diff --git a/internal/sysutil/net.go b/internal/sysutil/net.go index 557dd8d77d7..0e3b448ebcb 100644 --- a/internal/sysutil/net.go +++ b/internal/sysutil/net.go @@ -19,12 +19,12 @@ func IfaceSetStaticIP(ifaceName string) (err error) { } // GatewayIP returns IP address of interface's gateway. -func GatewayIP(ifaceName string) string { +func GatewayIP(ifaceName string) net.IP { cmd := exec.Command("ip", "route", "show", "dev", ifaceName) log.Tracef("executing %s %v", cmd.Path, cmd.Args) d, err := cmd.Output() if err != nil || cmd.ProcessState.ExitCode() != 0 { - return "" + return nil } fields := strings.Fields(string(d)) @@ -32,13 +32,8 @@ func GatewayIP(ifaceName string) string { // "default" at first field and default gateway IP address at third // field. if len(fields) < 3 || fields[0] != "default" { - return "" + return nil } - ip := net.ParseIP(fields[2]) - if ip == nil { - return "" - } - - return fields[2] + return net.ParseIP(fields[2]) } diff --git a/internal/sysutil/net_linux.go b/internal/sysutil/net_linux.go index 5206f9fda31..06d27eb2845 100644 --- a/internal/sysutil/net_linux.go +++ b/internal/sysutil/net_linux.go @@ -129,7 +129,7 @@ func ifaceSetStaticIP(ifaceName string) (err error) { return err } gatewayIP := GatewayIP(ifaceName) - add := updateStaticIPdhcpcdConf(ifaceName, ip, gatewayIP, ip4.String()) + add := updateStaticIPdhcpcdConf(ifaceName, ip, gatewayIP, ip4) body, err := ioutil.ReadFile("/etc/dhcpcd.conf") if err != nil { @@ -147,14 +147,14 @@ func ifaceSetStaticIP(ifaceName string) (err error) { // updateStaticIPdhcpcdConf sets static IP address for the interface by writing // into dhcpd.conf. -func updateStaticIPdhcpcdConf(ifaceName, ip, gatewayIP, dnsIP string) string { +func updateStaticIPdhcpcdConf(ifaceName, ip string, gatewayIP, dnsIP net.IP) string { var body []byte add := fmt.Sprintf("\ninterface %s\nstatic ip_address=%s\n", ifaceName, ip) body = append(body, []byte(add)...) - if len(gatewayIP) != 0 { + if gatewayIP != nil { add = fmt.Sprintf("static routers=%s\n", gatewayIP) body = append(body, []byte(add)...) diff --git a/internal/sysutil/net_linux_test.go b/internal/sysutil/net_linux_test.go index 8cadbbb76d7..a9851cb2789 100644 --- a/internal/sysutil/net_linux_test.go +++ b/internal/sysutil/net_linux_test.go @@ -4,6 +4,7 @@ package sysutil import ( "bytes" + "net" "testing" "github.com/stretchr/testify/assert" @@ -96,7 +97,7 @@ func TestSetStaticIPdhcpcdConf(t *testing.T) { `static routers=192.168.0.1` + nl + `static domain_name_servers=192.168.0.2` + nl + nl - s := updateStaticIPdhcpcdConf("wlan0", "192.168.0.2/24", "192.168.0.1", "192.168.0.2") + s := updateStaticIPdhcpcdConf("wlan0", "192.168.0.2/24", net.IP{192, 168, 0, 1}, net.IP{192, 168, 0, 2}) assert.Equal(t, dhcpcdConf, s) // without gateway @@ -104,6 +105,6 @@ func TestSetStaticIPdhcpcdConf(t *testing.T) { `static ip_address=192.168.0.2/24` + nl + `static domain_name_servers=192.168.0.2` + nl + nl - s = updateStaticIPdhcpcdConf("wlan0", "192.168.0.2/24", "", "192.168.0.2") + s = updateStaticIPdhcpcdConf("wlan0", "192.168.0.2/24", nil, net.IP{192, 168, 0, 2}) assert.Equal(t, dhcpcdConf, s) } diff --git a/internal/util/autohosts_test.go b/internal/util/autohosts_test.go index 04911142eb5..393646c7b39 100644 --- a/internal/util/autohosts_test.go +++ b/internal/util/autohosts_test.go @@ -42,7 +42,7 @@ func TestAutoHostsResolution(t *testing.T) { // Existing host ips := ah.Process("localhost", dns.TypeA) assert.NotNil(t, ips) - assert.Equal(t, 1, len(ips)) + assert.Len(t, ips, 1) assert.Equal(t, net.ParseIP("127.0.0.1"), ips[0]) // Unknown host @@ -107,7 +107,7 @@ func TestAutoHostsFSNotify(t *testing.T) { // Check if we are notified about changes ips = ah.Process("newhost", dns.TypeA) assert.NotNil(t, ips) - assert.Equal(t, 1, len(ips)) + assert.Len(t, ips, 1) assert.Equal(t, "127.0.0.2", ips[0].String()) } diff --git a/internal/util/helpers_test.go b/internal/util/helpers_test.go index d5e9063747e..68ebbabd2f5 100644 --- a/internal/util/helpers_test.go +++ b/internal/util/helpers_test.go @@ -8,7 +8,8 @@ import ( func TestSplitNext(t *testing.T) { s := " a,b , c " - assert.True(t, SplitNext(&s, ',') == "a") - assert.True(t, SplitNext(&s, ',') == "b") - assert.True(t, SplitNext(&s, ',') == "c" && len(s) == 0) + assert.Equal(t, "a", SplitNext(&s, ',')) + assert.Equal(t, "b", SplitNext(&s, ',')) + assert.Equal(t, "c", SplitNext(&s, ',')) + assert.Empty(t, s) }