-
Notifications
You must be signed in to change notification settings - Fork 1.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feature/override_resolvers: add -resolvers flag to attack command
This adds the ability to override the host operating system resolver by allowing this to be specified on the command line (comma separated list). In the event more than one is specified, the resolvers are tried in order in a round robin manner.
- Loading branch information
1 parent
43c736a
commit e35c4f6
Showing
3 changed files
with
211 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
package vegeta | ||
|
||
import ( | ||
"context" | ||
"errors" | ||
"fmt" | ||
"net" | ||
"strconv" | ||
"strings" | ||
"sync/atomic" | ||
) | ||
|
||
type resolver struct { | ||
addresses []string | ||
dialer *net.Dialer | ||
idx uint64 | ||
} | ||
|
||
// NewResolver - create a new instance of a dns resolver for plugging | ||
// into net.DefaultResolver. Addresses should be a list of | ||
// ip addresses and optional port numbers, separated by colon. | ||
// For example: 1.2.3.4:53 and 1.2.3.4 are both valid. In the absence | ||
// of a port number, 53 will be used instead. | ||
func NewResolver(addresses []string) (*net.Resolver, error) { | ||
normalAddresses, err := normalizeResolverAddresses(addresses) | ||
if err != nil { | ||
return nil, err | ||
} | ||
return &net.Resolver{ | ||
PreferGo: true, | ||
Dial: (&resolver{addresses: normalAddresses, dialer: &net.Dialer{}}).dial, | ||
}, nil | ||
} | ||
|
||
func normalizeResolverAddresses(addresses []string) ([]string, error) { | ||
if len(addresses) == 0 { | ||
return nil, errors.New("must specify at least resolver address") | ||
} | ||
normalAddresses := make([]string, len(addresses)) | ||
for i, addr := range addresses { | ||
ipPort := strings.Split(addr, ":") | ||
port := 53 | ||
var host string | ||
|
||
switch len(ipPort) { | ||
case 2: | ||
pu16, err := strconv.ParseUint(ipPort[1], 10, 16) | ||
if err != nil { | ||
return nil, err | ||
} | ||
port = int(pu16) | ||
fallthrough | ||
case 1: | ||
host = ipPort[0] | ||
default: | ||
return nil, fmt.Errorf("invalid ip:port specified: %s", addr) | ||
|
||
} | ||
ip := net.ParseIP(host) | ||
if ip == nil { | ||
return nil, fmt.Errorf("host %s is not an IP address", host) | ||
} | ||
|
||
normalAddresses[i] = fmt.Sprintf("%s:%d", host, port) | ||
} | ||
return normalAddresses, nil | ||
} | ||
|
||
func (r *resolver) dial(ctx context.Context, network, _ string) (net.Conn, error) { | ||
return r.dialer.DialContext(ctx, network, r.address()) | ||
} | ||
|
||
func (r *resolver) address() string { | ||
var address string | ||
if l := uint64(len(r.addresses)); l > 1 { | ||
idx := atomic.AddUint64(&r.idx, 1) | ||
address = r.addresses[idx%l] | ||
} else { | ||
address = r.addresses[0] | ||
} | ||
return address | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
package vegeta | ||
|
||
import ( | ||
"net" | ||
"net/http" | ||
"testing" | ||
) | ||
|
||
func TestResolve8888(t *testing.T) { | ||
r, err := NewResolver([]string{"8.8.8.8:53"}) | ||
if err != nil { | ||
t.FailNow() | ||
} | ||
|
||
net.DefaultResolver = r | ||
|
||
resp, err := http.Get("https://www.google.com/") | ||
|
||
if err != nil { | ||
t.Logf("error from http.Get(): %s", err) | ||
t.FailNow() | ||
} | ||
|
||
resp.Body.Close() | ||
} | ||
|
||
func TestResolveAddresses(t *testing.T) { | ||
table := map[string]struct { | ||
input []string | ||
want []string | ||
expectError bool | ||
expectMismatch bool | ||
}{ | ||
"Good list": { | ||
input: []string{ | ||
"8.8.8.8:53", | ||
"9.9.9.9:1234", | ||
"2.3.4.5", | ||
}, | ||
want: []string{ | ||
"8.8.8.8:53", | ||
"9.9.9.9:1234", | ||
"2.3.4.5:53", | ||
}, | ||
expectError: false, | ||
expectMismatch: false, | ||
}, | ||
"Mismatch list": { | ||
input: []string{ | ||
"9.9.9.9:1234", | ||
}, | ||
want: []string{ | ||
"9.9.9.9:53", | ||
}, | ||
expectError: false, | ||
expectMismatch: true, | ||
}, | ||
"Parse error list": { | ||
input: []string{ | ||
"abcd.com:53", | ||
}, | ||
expectError: true, | ||
}, | ||
} | ||
for subtest, tdata := range table { | ||
t.Run(subtest, func(t *testing.T) { | ||
addrs, err := normalizeResolverAddresses(tdata.input) | ||
if tdata.expectError { | ||
if err == nil { | ||
t.Error("expected error, got none") | ||
} | ||
return | ||
|
||
} | ||
|
||
if err != nil { | ||
t.Errorf("expected nil error, got: %s", err) | ||
return | ||
} | ||
|
||
match := true | ||
if len(tdata.want) != len(addrs) { | ||
match = false | ||
} else { | ||
for i, addr := range addrs { | ||
if addr != tdata.want[i] { | ||
match = false | ||
break | ||
} | ||
} | ||
} | ||
if !tdata.expectMismatch && !match { | ||
t.Errorf("unexpected mismatch, input: %#v, want: %#v", addrs, tdata.want) | ||
} | ||
|
||
}) | ||
} | ||
|
||
} | ||
|
||
func TestResolverOverflow(t *testing.T) { | ||
res := &resolver{ | ||
addresses: []string{"8.8.8.8:53", "9.9.9.9:53"}, | ||
idx: ^uint64(0), | ||
} | ||
_ = res.address() | ||
if res.idx != 0 { | ||
t.Error("overflow not handled gracefully") | ||
} | ||
// throw away another one to make sure we're back to 0 | ||
_ = res.address() | ||
for i := 0; i < 5; i++ { | ||
addr := res.address() | ||
if expectedAddr := res.addresses[i%len(res.addresses)]; expectedAddr != addr { | ||
t.Errorf("address mismatch, have: %s, want: %s", addr, expectedAddr) | ||
} | ||
} | ||
} |