Skip to content

Commit

Permalink
Merge pull request #308 from nathanejohnson/feature/override_resolvers
Browse files Browse the repository at this point in the history
attack: Add -resolvers flag to command
  • Loading branch information
tsenart authored Aug 30, 2018
2 parents 43c736a + 3b39fee commit fe1b142
Show file tree
Hide file tree
Showing 5 changed files with 300 additions and 1 deletion.
13 changes: 12 additions & 1 deletion attack.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,14 @@ import (
"fmt"
"io"
"io/ioutil"
"net"
"net/http"
"os"
"os/signal"
"strings"
"time"

"github.com/tsenart/vegeta/internal/resolver"
vegeta "github.com/tsenart/vegeta/lib"
)

Expand All @@ -25,7 +27,6 @@ func attackCmd() command {
rate: vegeta.Rate{Freq: 50, Per: time.Second},
maxBody: vegeta.DefaultMaxBody,
}

fs.StringVar(&opts.name, "name", "", "Attack name")
fs.StringVar(&opts.targetsf, "targets", "stdin", "Targets file")
fs.StringVar(&opts.format, "format", vegeta.HTTPTargetFormat,
Expand All @@ -49,6 +50,7 @@ func attackCmd() command {
fs.Var(&opts.headers, "header", "Request header")
fs.Var(&opts.laddr, "laddr", "Local IP address")
fs.BoolVar(&opts.keepalive, "keepalive", true, "Use persistent connections")
systemSpecificFlags(fs, opts)

return command{fs, func(args []string) error {
fs.Parse(args)
Expand Down Expand Up @@ -85,6 +87,7 @@ type attackOpts struct {
headers headers
laddr localAddr
keepalive bool
resolvers csl
}

// attack validates the attack arguments, sets up the
Expand All @@ -94,6 +97,14 @@ func attack(opts *attackOpts) (err error) {
return errZeroRate
}

if len(opts.resolvers) > 0 {
res, err := resolver.NewResolver(opts.resolvers)
if err != nil {
return err
}
net.DefaultResolver = res
}

files := map[string]io.Reader{}
for _, filename := range []string{opts.targetsf, opts.bodyf} {
if filename == "" {
Expand Down
9 changes: 9 additions & 0 deletions attack_nonwindows.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
// +build !windows

package main

import "flag"

func systemSpecificFlags(fs *flag.FlagSet, opts *attackOpts) {
fs.Var(&opts.resolvers, "resolvers", "List of addresses (ip:port) to use for DNS resolution. Disables use of local system DNS. (comma separated list)")
}
5 changes: 5 additions & 0 deletions attack_windows.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package main

import "flag"

func systemSpecificFlags(fs *flag.FlagSet, opts *attackOpts) {}
82 changes: 82 additions & 0 deletions internal/resolver/resolver.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
package resolver

import (
"context"
"errors"
"fmt"
"net"
"strconv"
"strings"
"sync/atomic"
)

type resolver struct {
addrs []string
dialer *net.Dialer
idx uint64
}

// NewResolver - create a new instance of a dns resolver for plugging
// into net.DefaultResolver. Addresses should be a list of
// ip addrs and optional port numbers, separated by colon.
// For example: 1.2.3.4:53 and 1.2.3.4 are both valid. In the absence
// of a port number, 53 will be used instead.
func NewResolver(addrs []string) (*net.Resolver, error) {
if len(addrs) == 0 {
return nil, errors.New("must specify at least resolver address")
}
cleanAddrs, err := normalizeAddrs(addrs)
if err != nil {
return nil, err
}
return &net.Resolver{
PreferGo: true,
Dial: (&resolver{addrs: cleanAddrs, dialer: &net.Dialer{}}).dial,
}, nil
}

func normalizeAddrs(addrs []string) ([]string, error) {
normal := make([]string, len(addrs))
for i, addr := range addrs {

// if addr has no port, give it 53
if !strings.Contains(addr, ":") {
addr += ":53"
}

// validate addr is a valid host:port
host, portstr, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
}

// validate valid port.
port, err := strconv.ParseUint(portstr, 10, 16)
if err != nil {
return nil, err
}

if port <= 0 {
return nil, errors.New("invalid port")
}

// make sure host is an ip.
ip := net.ParseIP(host)
if ip == nil {
return nil, fmt.Errorf("host %s is not an IP address", host)
}

normal[i] = addr
}
return normal, nil
}

// ignore the third parameter, as this represents the dns server address that
// we are overriding.
func (r *resolver) dial(ctx context.Context, network, _ string) (net.Conn, error) {
return r.dialer.DialContext(ctx, network, r.address())
}

func (r *resolver) address() string {
return r.addrs[atomic.AddUint64(&r.idx, 1)%uint64(len(r.addrs))]
}
192 changes: 192 additions & 0 deletions internal/resolver/resolver_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
package resolver

import (
"fmt"
"io/ioutil"
"net"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"

"github.com/miekg/dns"
)

const (
fakeDomain = "acme.notadomain"
)

func TestResolveMiekg(t *testing.T) {

dns.HandleFunc(".", func(w dns.ResponseWriter, r *dns.Msg) {
m := &dns.Msg{}
m.SetReply(r)
localIP := net.ParseIP("127.0.0.1")
defer func() {
err := w.WriteMsg(m)
if err != nil {
t.Logf("got error writing dns message: %s", err)
}
}()
if len(r.Question) == 0 {
m.RecursionAvailable = true
m.SetRcode(r, dns.RcodeRefused)
return
}

q := r.Question[0]

if q.Name == fakeDomain+"." {
m.Answer = []dns.RR{&dns.A{
Hdr: dns.RR_Header{
Name: q.Name,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: 1,
},
A: localIP,
}}
} else {
m.SetRcode(r, dns.RcodeNameError)
}
})
const payload = "there is no cloud, just someone else's computer"

done := make(chan struct{})

ds := dns.Server{
Addr: "127.0.0.1:0",
Net: "udp",
UDPSize: dns.MinMsgSize,
ReadTimeout: 2 * time.Second,
WriteTimeout: 2 * time.Second,
// Unsafe instructs the server to disregard any sanity checks and directly hand the message to
// the handler. It will specifically not check if the query has the QR bit not set.
Unsafe: false,
NotifyStartedFunc: func() { close(done) },
}

go func() {
err := ds.ListenAndServe()
if err != nil {
t.Logf("got error during dns ListenAndServe: %s", err)
}
}()

defer func() {
_ = ds.Shutdown()
}()

// wait for notify function to be called, ensuring ds.PacketConn is not nil.
<-done

res, err := NewResolver([]string{ds.PacketConn.LocalAddr().String()})
if err != nil {
t.Errorf("error from NewResolver: %s", err)
return
}
net.DefaultResolver = res

ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, payload)
}))
defer ts.Close()

tsurl, _ := url.Parse(ts.URL)

_, hport, err := net.SplitHostPort(tsurl.Host)
if err != nil {
t.Errorf("could not parse port from httptest url %s: %s", ts.URL, err)
return
}
tsurl.Host = net.JoinHostPort(fakeDomain, hport)
resp, err := http.Get(tsurl.String())
if err != nil {
t.Errorf("failed resolver round trip: %s", err)
return
}
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
t.Errorf("failed to read respose body")
return
}
if strings.TrimSpace(string(body)) != payload {
t.Errorf("body mismatch, got: '%s', expected: '%s'", body, payload)
}
}

func TestResolveAddresses(t *testing.T) {
table := map[string]struct {
input []string
want []string
expectError bool
expectMismatch bool
}{
"Good list": {
input: []string{
"8.8.8.8:53",
"9.9.9.9:1234",
"2.3.4.5",
},
want: []string{
"8.8.8.8:53",
"9.9.9.9:1234",
"2.3.4.5:53",
},
expectError: false,
expectMismatch: false,
},
"Mismatch list": {
input: []string{
"9.9.9.9:1234",
},
want: []string{
"9.9.9.9:53",
},
expectError: false,
expectMismatch: true,
},
"Parse error list": {
input: []string{
"abcd.com:53",
},
expectError: true,
},
}
for subtest, tdata := range table {
t.Run(subtest, func(t *testing.T) {
addrs, err := normalizeAddrs(tdata.input)
if tdata.expectError {
if err == nil {
t.Error("expected error, got none")
}
return

}

if err != nil {
t.Errorf("expected nil error, got: %s", err)
return
}

match := true
if len(tdata.want) != len(addrs) {
match = false
} else {
for i, addr := range addrs {
if addr != tdata.want[i] {
match = false
break
}
}
}
if !tdata.expectMismatch && !match {
t.Errorf("unexpected mismatch, input: %#v, want: %#v", addrs, tdata.want)
}

})
}

}

0 comments on commit fe1b142

Please sign in to comment.