diff --git a/cmd/options.go b/cmd/options.go index decdfe453d2..9977d8757d4 100644 --- a/cmd/options.go +++ b/cmd/options.go @@ -63,6 +63,7 @@ func optionFlagSet() *pflag.FlagSet { flags.Duration("min-iteration-duration", 0, "minimum amount of time k6 will take executing a single iteration") flags.BoolP("throw", "w", false, "throw warnings (like failed http requests) as errors") flags.StringSlice("blacklist-ip", nil, "blacklist an `ip range` from being called") + flags.StringSlice("block-hostname", nil, "block a case-insensitive hostname `pattern`, with optional leading wildcard, from being called") // The comment about system-tags also applies for summary-trend-stats. The default values // are set in applyDefault(). @@ -151,6 +152,19 @@ func getOptions(flags *pflag.FlagSet) (lib.Options, error) { opts.BlacklistIPs = append(opts.BlacklistIPs, net) } + blockedHostnameStrings, err := flags.GetStringSlice("block-hostname") + if err != nil { + return opts, err + } + if len(blockedHostnameStrings) > 0 { + opts.BlockedHostnames = &lib.HostnameTrie{} + } + for _, s := range blockedHostnameStrings { + if insertErr := opts.BlockedHostnames.Insert(s); insertErr != nil { + return opts, errors.Wrap(insertErr, "block-hostname") + } + } + if flags.Changed("summary-trend-stats") { trendStats, errSts := flags.GetStringSlice("summary-trend-stats") if errSts != nil { diff --git a/js/runner.go b/js/runner.go index 03da8c07bd8..e2a2f14edf3 100644 --- a/js/runner.go +++ b/js/runner.go @@ -154,10 +154,11 @@ func (r *Runner) newVU(samplesOut chan<- stats.SampleContainer) (*VU, error) { } dialer := &netext.Dialer{ - Dialer: r.BaseDialer, - Resolver: r.Resolver, - Blacklist: r.Bundle.Options.BlacklistIPs, - Hosts: r.Bundle.Options.Hosts, + Dialer: r.BaseDialer, + Resolver: r.Resolver, + Blacklist: r.Bundle.Options.BlacklistIPs, + BlockedHostnames: r.Bundle.Options.BlockedHostnames, + Hosts: r.Bundle.Options.Hosts, } tlsConfig := &tls.Config{ InsecureSkipVerify: r.Bundle.Options.InsecureSkipTLSVerify.Bool, diff --git a/js/runner_test.go b/js/runner_test.go index 13f2c3a0eda..f4f9f261719 100644 --- a/js/runner_test.go +++ b/js/runner_test.go @@ -818,6 +818,75 @@ func TestVUIntegrationBlacklistScript(t *testing.T) { } } +func TestVUIntegrationBlockHostnamesOption(t *testing.T) { + r1, err := getSimpleRunner("/script.js", ` + import http from "k6/http"; + export default function() { http.get("https://k6.io/"); } + `) + require.NoError(t, err) + + hostnames := lib.HostnameTrie{} + require.NoError(t, hostnames.Insert("*.io")) + require.NoError(t, r1.SetOptions(lib.Options{ + Throw: null.BoolFrom(true), + BlockedHostnames: &hostnames, + })) + + r2, err := NewFromArchive(r1.MakeArchive(), lib.RuntimeOptions{}) + if !assert.NoError(t, err) { + return + } + + runners := map[string]*Runner{"Source": r1, "Archive": r2} + + for name, r := range runners { + r := r + t.Run(name, func(t *testing.T) { + vu, err := r.NewVU(make(chan stats.SampleContainer, 100)) + require.NoError(t, err) + err = vu.RunOnce(context.Background()) + require.Error(t, err) + assert.Contains(t, err.Error(), "hostname (k6.io) is in a blocked pattern (*.io)") + }) + } +} + +func TestVUIntegrationBlockHostnamesScript(t *testing.T) { + r1, err := getSimpleRunner("/script.js", ` + import http from "k6/http"; + + export let options = { + throw: true, + blockHostnames: ["*.io"], + }; + + export default function() { http.get("https://k6.io/"); } + `) + if !assert.NoError(t, err) { + return + } + + r2, err := NewFromArchive(r1.MakeArchive(), lib.RuntimeOptions{}) + if !assert.NoError(t, err) { + return + } + + runners := map[string]*Runner{"Source": r1, "Archive": r2} + + for name, r := range runners { + r := r + t.Run(name, func(t *testing.T) { + vu, err := r.NewVU(make(chan stats.SampleContainer, 100)) + if !assert.NoError(t, err) { + return + } + err = vu.RunOnce(context.Background()) + require.Error(t, err) + assert.Contains(t, err.Error(), "hostname (k6.io) is in a blocked pattern (*.io)") + }) + } +} + func TestVUIntegrationHosts(t *testing.T) { tb := httpmultibin.NewHTTPMultiBin(t) defer tb.Cleanup() diff --git a/lib/netext/dialer.go b/lib/netext/dialer.go index cecc4e735a7..72ee1d9fff7 100644 --- a/lib/netext/dialer.go +++ b/lib/netext/dialer.go @@ -40,9 +40,10 @@ import ( type Dialer struct { net.Dialer - Resolver *dnscache.Resolver - Blacklist []*lib.IPNet - Hosts map[string]net.IP + Resolver *dnscache.Resolver + Blacklist []*lib.IPNet + BlockedHostnames *lib.HostnameTrie + Hosts map[string]net.IP BytesRead int64 BytesWritten int64 @@ -66,11 +67,27 @@ func (b BlackListedIPError) Error() string { return fmt.Sprintf("IP (%s) is in a blacklisted range (%s)", b.ip, b.net) } +// BlockedHostError is returned when a given hostname is blocked +type BlockedHostError struct { + hostname string + match string +} + +func (b BlockedHostError) Error() string { + return fmt.Sprintf("hostname (%s) is in a blocked pattern (%s)", b.hostname, b.match) +} + // DialContext wraps the net.Dialer.DialContext and handles the k6 specifics func (d *Dialer) DialContext(ctx context.Context, proto, addr string) (net.Conn, error) { delimiter := strings.LastIndex(addr, ":") host := addr[:delimiter] + if d.BlockedHostnames != nil { + if match, blocked := d.BlockedHostnames.Contains(host); blocked { + return nil, BlockedHostError{hostname: host, match: match} + } + } + // lookup for domain defined in Hosts option before trying to resolve DNS. ip, ok := d.Hosts[host] if !ok { diff --git a/lib/netext/httpext/error_codes.go b/lib/netext/httpext/error_codes.go index b240d6841c2..6cb204b7d48 100644 --- a/lib/netext/httpext/error_codes.go +++ b/lib/netext/httpext/error_codes.go @@ -46,9 +46,10 @@ const ( defaultErrorCode errCode = 1000 defaultNetNonTCPErrorCode errCode = 1010 // DNS errors - defaultDNSErrorCode errCode = 1100 - dnsNoSuchHostErrorCode errCode = 1101 - blackListedIPErrorCode errCode = 1110 + defaultDNSErrorCode errCode = 1100 + dnsNoSuchHostErrorCode errCode = 1101 + blackListedIPErrorCode errCode = 1110 + blockedHostnameErrorCode errCode = 1111 // tcp errors defaultTCPErrorCode errCode = 1200 tcpBrokenPipeErrorCode errCode = 1201 @@ -90,6 +91,7 @@ const ( netUnknownErrnoErrorCodeMsg = "%s: unknown errno `%d` on %s with message `%s`" dnsNoSuchHostErrorCodeMsg = "lookup: no such host" blackListedIPErrorCodeMsg = "ip is blacklisted" + blockedHostnameErrorMsg = "hostname is blocked" http2GoAwayErrorCodeMsg = "http2: received GoAway with http2 ErrCode %s" http2StreamErrorCodeMsg = "http2: stream error with http2 ErrCode %s" http2ConnectionErrorCodeMsg = "http2: connection error with http2 ErrCode %s" @@ -118,6 +120,8 @@ func errorCodeForError(err error) (errCode, string) { } case netext.BlackListedIPError: return blackListedIPErrorCode, blackListedIPErrorCodeMsg + case netext.BlockedHostError: + return blockedHostnameErrorCode, blockedHostnameErrorMsg case *http2.GoAwayError: return unknownHTTP2GoAwayErrorCode + http2ErrCodeOffset(e.ErrCode), fmt.Sprintf(http2GoAwayErrorCodeMsg, e.ErrCode) diff --git a/lib/options.go b/lib/options.go index f94341a3ff3..9a7e3495a89 100644 --- a/lib/options.go +++ b/lib/options.go @@ -26,6 +26,8 @@ import ( "fmt" "net" "reflect" + "regexp" + "strings" "github.com/loadimpact/k6/lib/scheduler" "github.com/loadimpact/k6/lib/types" @@ -187,6 +189,103 @@ func ParseCIDR(s string) (*IPNet, error) { return &parsedIPNet, nil } +// HostnameTrie is a tree-structured list of hostname matches with support +// for wildcards exclusively at the start of the pattern. Items may only +// be inserted and searched. Internationalized hostnames are valid. +type HostnameTrie struct { + children map[rune]*HostnameTrie +} + +// Regex description of hostname pattern to enforce blocks by. Global var +// to avoid compilation penalty at runtime. +// Matches against strings composed entirely of letters, numbers, or '.'s +// with an optional wildcard at the start. +//nolint:gochecknoglobals +var legalHostnamePattern *regexp.Regexp = regexp.MustCompile(`^\*?(\pL|[0-9\.])*`) + +func legalHostname(s string) error { + if len(legalHostnamePattern.FindString(s)) != len(s) { + return errors.Errorf("invalid hostname pattern %s", s) + } + return nil +} + +// UnmarshalJSON forms a HostnameTrie from the provided hostname pattern +// list. +func (t *HostnameTrie) UnmarshalJSON(data []byte) error { + m := make([]string, 0) + if err := json.Unmarshal(data, &m); err != nil { + return err + } + for _, h := range m { + if insertErr := t.Insert(h); insertErr != nil { + return insertErr + } + } + return nil +} + +// UnmarshalText forms a HostnameTrie from a comma-delimited list +// of hostname patterns. +func (t *HostnameTrie) UnmarshalText(b []byte) error { + for _, s := range strings.Split(string(b), ",") { + if err := t.Insert(s); err != nil { + return err + } + } + return nil +} + +// Insert a hostname pattern into the given HostnameTrie. Returns an error +// if hostname pattern is illegal. +func (t *HostnameTrie) Insert(s string) error { + s = strings.ToLower(s) + if len(s) == 0 { + return nil + } + + if err := legalHostname(s); err != nil { + return err + } + + // mask creation of the trie by initializing the root here + if t.children == nil { + t.children = make(map[rune]*HostnameTrie) + } + + rStr := []rune(s) // need to iterate by runes for intl' names + last := len(rStr) - 1 + if c, ok := t.children[rStr[last]]; ok { + return c.Insert(string(rStr[:last])) + } + + t.children[rStr[last]] = &HostnameTrie{make(map[rune]*HostnameTrie)} + return t.children[rStr[last]].Insert(string(rStr[:last])) +} + +// Contains returns whether s matches a pattern in the HostnameTrie +// along with the matching pattern, if one was found. +func (t *HostnameTrie) Contains(s string) (matchedPattern string, matchFound bool) { + s = strings.ToLower(s) + if len(s) == 0 { + return s, len(t.children) == 0 + } + + rStr := []rune(s) + last := len(rStr) - 1 + if c, ok := t.children[rStr[last]]; ok { + if match, matched := c.Contains(string(rStr[:last])); matched { + return match + string(rStr[last]), true + } + } + + if _, wild := t.children['*']; wild { + return "*", true + } + + return "", false +} + type Options struct { // Should the test start in a paused state? Paused null.Bool `json:"paused" envconfig:"K6_PAUSED"` @@ -242,6 +341,9 @@ type Options struct { // Blacklist IP ranges that tests may not contact. Mainly useful in hosted setups. BlacklistIPs []*IPNet `json:"blacklistIPs" envconfig:"K6_BLACKLIST_IPS"` + // Block hostname patterns that tests may not contact. + BlockedHostnames *HostnameTrie `json:"blockHostnames" envconfig:"K6_BLOCK_HOSTNAMES"` + // Hosts overrides dns entries for given hosts Hosts map[string]net.IP `json:"hosts" envconfig:"K6_HOSTS"` @@ -389,6 +491,9 @@ func (o Options) Apply(opts Options) Options { if opts.BlacklistIPs != nil { o.BlacklistIPs = opts.BlacklistIPs } + if opts.BlockedHostnames != nil { + o.BlockedHostnames = opts.BlockedHostnames + } if opts.Hosts != nil { o.Hosts = opts.Hosts } diff --git a/lib/options_test.go b/lib/options_test.go index c7d86928636..23ff4d44d4a 100644 --- a/lib/options_test.go +++ b/lib/options_test.go @@ -318,6 +318,34 @@ func TestOptions(t *testing.T) { assert.Equal(t, net.IPv4zero, opts.BlacklistIPs[0].IP) assert.Equal(t, net.CIDRMask(1, 1), opts.BlacklistIPs[0].Mask) }) + t.Run("BlockedHostnames", func(t *testing.T) { + hostnames := HostnameTrie{} + assert.NoError(t, hostnames.Insert("test.k6.io")) + assert.Error(t, hostnames.Insert("inval*d.pattern")) + assert.NoError(t, hostnames.Insert("*valid.pattern")) + opts := Options{}.Apply(Options{ + BlockedHostnames: &hostnames, + }) + assert.NotNil(t, opts.BlockedHostnames) + assert.NotEmpty(t, opts.BlockedHostnames) + _, matches := opts.BlockedHostnames.Contains("K6.Io") + assert.False(t, matches) + match, matches := opts.BlockedHostnames.Contains("tEsT.k6.Io") + assert.True(t, matches) + assert.Equal(t, "test.k6.io", match) + match, matches = opts.BlockedHostnames.Contains("TEST.K6.IO") + assert.True(t, matches) + assert.Equal(t, "test.k6.io", match) + match, matches = opts.BlockedHostnames.Contains("blocked.valId.paTtern") + assert.True(t, matches) + assert.Equal(t, "*valid.pattern", match) + _, matches = opts.BlockedHostnames.Contains("example.test.k6.io") + assert.False(t, matches) + assert.NoError(t, opts.BlockedHostnames.Insert("*.test.k6.io")) + match, matches = opts.BlockedHostnames.Contains("example.test.k6.io") + assert.True(t, matches) + assert.Equal(t, "*.test.k6.io", match) + }) t.Run("Hosts", func(t *testing.T) { opts := Options{}.Apply(Options{Hosts: map[string]net.IP{