diff --git a/lib/netext/dialer_test.go b/lib/netext/dialer_test.go index 6f1a546532b0..ea4c0eded591 100644 --- a/lib/netext/dialer_test.go +++ b/lib/netext/dialer_test.go @@ -37,16 +37,22 @@ func (r testResolver) FetchOne(host string) (net.IP, error) { return r.hosts[hos func TestDialerAddr(t *testing.T) { dialer := newDialerWithResolver(net.Dialer{}, newResolver()) dialer.Hosts = map[string]*lib.HostAddress{ - "example.com": {IP: net.ParseIP("3.4.5.6")}, - "example.com:443": {IP: net.ParseIP("3.4.5.6"), Port: 8443}, - "example.com:8080": {IP: net.ParseIP("3.4.5.6"), Port: 9090}, - "example-deny-host.com": {IP: net.ParseIP("8.9.10.11")}, + "example.com": {IP: net.ParseIP("3.4.5.6")}, + "example.com:443": {IP: net.ParseIP("3.4.5.6"), Port: 8443}, + "example.com:8080": {IP: net.ParseIP("3.4.5.6"), Port: 9090}, + "example-deny-host.com": {IP: net.ParseIP("8.9.10.11")}, + "example-ipv6.com": {IP: net.ParseIP("2001:db8::68")}, + "example-ipv6.com:443": {IP: net.ParseIP("2001:db8::68"), Port: 8443}, + "example-ipv6-deny-host.com": {IP: net.ParseIP("::1")}, } ipNet, err := lib.ParseCIDR("8.9.10.0/24") assert.NoError(t, err) - dialer.Blacklist = []*lib.IPNet{ipNet} + ipV6Net, err := lib.ParseCIDR("::1/24") + assert.NoError(t, err) + + dialer.Blacklist = []*lib.IPNet{ipNet, ipV6Net} addr, err := dialer.dialAddr("example-resolver.com:80") assert.NoError(t, err) @@ -64,18 +70,33 @@ func TestDialerAddr(t *testing.T) { assert.NoError(t, err) assert.Equal(t, "3.4.5.6:9090", addr) + addr, err = dialer.dialAddr("example-ipv6.com:80") + assert.NoError(t, err) + assert.Equal(t, "[2001:db8::68]:80", addr) + + addr, err = dialer.dialAddr("example-ipv6.com:443") + assert.NoError(t, err) + assert.Equal(t, "[2001:db8::68]:8443", addr) + _, err = dialer.dialAddr("example-deny-resolver.com:80") assert.EqualError(t, err, "IP (8.9.10.11) is in a blacklisted range (8.9.10.0/24)") _, err = dialer.dialAddr("example-deny-host.com:80") assert.EqualError(t, err, "IP (8.9.10.11) is in a blacklisted range (8.9.10.0/24)") + + _, err = dialer.dialAddr("example-ipv6-deny-resolver.com:80") + assert.EqualError(t, err, "IP (::1) is in a blacklisted range (::/24)") + + _, err = dialer.dialAddr("example-ipv6-deny-host.com:80") + assert.EqualError(t, err, "IP (::1) is in a blacklisted range (::/24)") } func newResolver() testResolver { return testResolver{ hosts: map[string]net.IP{ - "example-resolver.com": net.ParseIP("1.2.3.4"), - "example-deny-resolver.com": net.ParseIP("8.9.10.11"), + "example-resolver.com": net.ParseIP("1.2.3.4"), + "example-deny-resolver.com": net.ParseIP("8.9.10.11"), + "example-ipv6-deny-resolver.com": net.ParseIP("::1"), }, } } diff --git a/lib/options.go b/lib/options.go index 18142b45453d..ca00b16340ca 100644 --- a/lib/options.go +++ b/lib/options.go @@ -21,6 +21,7 @@ package lib import ( + "bytes" "crypto/tls" "encoding/json" "fmt" @@ -40,6 +41,8 @@ import ( // iterations+vus, or stages) const DefaultScenarioName = "default" +const defaultHostPort = 80 + // DefaultSummaryTrendStats are the default trend columns shown in the test summary output // nolint: gochecknoglobals var DefaultSummaryTrendStats = []string{"avg", "min", "med", "max", "p(90)", "p(95)"} @@ -223,17 +226,11 @@ func (h *HostAddress) UnmarshalText(text []byte) error { return &net.ParseError{Type: "IP address", Text: ""} } - s := string(text) - host, port, err := net.SplitHostPort(s) + ip, port, err := splitHostPort(text) if err != nil { return err } - ip := net.ParseIP(host) - if ip == nil { - return &net.ParseError{Type: "IP address", Text: s} - } - nh, err := NewHostAddress(ip, port) if err != nil { return err @@ -243,6 +240,32 @@ func (h *HostAddress) UnmarshalText(text []byte) error { return nil } +func splitHostPort(text []byte) (net.IP, string, error) { + host := string(text) + var port string + + if isHostPort(text) { + var err error + host, port, err = net.SplitHostPort(host) + if err != nil { + return nil, "", err + } + } + + ip := net.ParseIP(host) + if ip == nil { + return nil, "", &net.ParseError{Type: "IP address", Text: host} + } + + return ip, port, nil +} + +func isHostPort(text []byte) bool { + return bytes.ContainsRune(text, ':') && + (bytes.ContainsRune(text, '.') || // ipV4 + (bytes.ContainsRune(text, '[') && bytes.ContainsRune(text, ']'))) // ipV6 +} + // ParseCIDR creates an IPNet out of a CIDR string func ParseCIDR(s string) (*IPNet, error) { _, ipnet, err := net.ParseCIDR(s) diff --git a/lib/options_test.go b/lib/options_test.go index 8abe98d78181..872a3274e202 100644 --- a/lib/options_test.go +++ b/lib/options_test.go @@ -485,7 +485,7 @@ func TestCIDRUnmarshal(t *testing.T) { testData := []struct { input string expectedOutput *IPNet - expactFailure bool + expectFailure bool }{ { "10.0.0.0/8", @@ -514,7 +514,7 @@ func TestCIDRUnmarshal(t *testing.T) { actualIPNet := &IPNet{} err := actualIPNet.UnmarshalText([]byte(data.input)) - if data.expactFailure { + if data.expectFailure { require.EqualError(t, err, "Failed to parse CIDR: invalid CIDR address: "+data.input) } else { require.NoError(t, err) @@ -523,3 +523,62 @@ func TestCIDRUnmarshal(t *testing.T) { }) } } + +func TestHostAddressUnmarshal(t *testing.T) { + var testData = []struct { + input string + expectedOutput *HostAddress + expectFailure string + }{ + { + "1.2.3.4", + &HostAddress{IP: net.ParseIP("1.2.3.4")}, + "", + }, + { + "1.2.3.4:80", + &HostAddress{IP: net.ParseIP("1.2.3.4"), Port: 80}, + "", + }, + { + "1.2.3.4:asdf", + nil, + "strconv.Atoi: parsing \"asdf\": invalid syntax", + }, + { + "2001:0db8:0000:0000:0000:ff00:0042:8329", + &HostAddress{IP: net.ParseIP("2001:0db8:0000:0000:0000:ff00:0042:8329")}, + "", + }, + { + "2001:db8::68", + &HostAddress{IP: net.ParseIP("2001:db8::68")}, + "", + }, + { + "[2001:db8::68]:80", + &HostAddress{IP: net.ParseIP("2001:db8::68"), Port: 80}, + "", + }, + { + "[2001:db8::68]:asdf", + nil, + "strconv.Atoi: parsing \"asdf\": invalid syntax", + }, + } + + for _, data := range testData { + data := data + t.Run(data.input, func(t *testing.T) { + actualHost := &HostAddress{} + err := actualHost.UnmarshalText([]byte(data.input)) + + if data.expectFailure != "" { + require.EqualError(t, err, data.expectFailure) + } else { + require.NoError(t, err) + assert.Equal(t, data.expectedOutput, actualHost) + } + }) + } +}