-
Notifications
You must be signed in to change notification settings - Fork 1.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #308 from nathanejohnson/feature/override_resolvers
attack: Add -resolvers flag to command
- Loading branch information
Showing
5 changed files
with
300 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
// +build !windows | ||
|
||
package main | ||
|
||
import "flag" | ||
|
||
func systemSpecificFlags(fs *flag.FlagSet, opts *attackOpts) { | ||
fs.Var(&opts.resolvers, "resolvers", "List of addresses (ip:port) to use for DNS resolution. Disables use of local system DNS. (comma separated list)") | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
package main | ||
|
||
import "flag" | ||
|
||
func systemSpecificFlags(fs *flag.FlagSet, opts *attackOpts) {} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
package resolver | ||
|
||
import ( | ||
"context" | ||
"errors" | ||
"fmt" | ||
"net" | ||
"strconv" | ||
"strings" | ||
"sync/atomic" | ||
) | ||
|
||
type resolver struct { | ||
addrs []string | ||
dialer *net.Dialer | ||
idx uint64 | ||
} | ||
|
||
// NewResolver - create a new instance of a dns resolver for plugging | ||
// into net.DefaultResolver. Addresses should be a list of | ||
// ip addrs and optional port numbers, separated by colon. | ||
// For example: 1.2.3.4:53 and 1.2.3.4 are both valid. In the absence | ||
// of a port number, 53 will be used instead. | ||
func NewResolver(addrs []string) (*net.Resolver, error) { | ||
if len(addrs) == 0 { | ||
return nil, errors.New("must specify at least resolver address") | ||
} | ||
cleanAddrs, err := normalizeAddrs(addrs) | ||
if err != nil { | ||
return nil, err | ||
} | ||
return &net.Resolver{ | ||
PreferGo: true, | ||
Dial: (&resolver{addrs: cleanAddrs, dialer: &net.Dialer{}}).dial, | ||
}, nil | ||
} | ||
|
||
func normalizeAddrs(addrs []string) ([]string, error) { | ||
normal := make([]string, len(addrs)) | ||
for i, addr := range addrs { | ||
|
||
// if addr has no port, give it 53 | ||
if !strings.Contains(addr, ":") { | ||
addr += ":53" | ||
} | ||
|
||
// validate addr is a valid host:port | ||
host, portstr, err := net.SplitHostPort(addr) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
// validate valid port. | ||
port, err := strconv.ParseUint(portstr, 10, 16) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
if port <= 0 { | ||
return nil, errors.New("invalid port") | ||
} | ||
|
||
// make sure host is an ip. | ||
ip := net.ParseIP(host) | ||
if ip == nil { | ||
return nil, fmt.Errorf("host %s is not an IP address", host) | ||
} | ||
|
||
normal[i] = addr | ||
} | ||
return normal, nil | ||
} | ||
|
||
// ignore the third parameter, as this represents the dns server address that | ||
// we are overriding. | ||
func (r *resolver) dial(ctx context.Context, network, _ string) (net.Conn, error) { | ||
return r.dialer.DialContext(ctx, network, r.address()) | ||
} | ||
|
||
func (r *resolver) address() string { | ||
return r.addrs[atomic.AddUint64(&r.idx, 1)%uint64(len(r.addrs))] | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,192 @@ | ||
package resolver | ||
|
||
import ( | ||
"fmt" | ||
"io/ioutil" | ||
"net" | ||
"net/http" | ||
"net/http/httptest" | ||
"net/url" | ||
"strings" | ||
"testing" | ||
"time" | ||
|
||
"github.com/miekg/dns" | ||
) | ||
|
||
const ( | ||
fakeDomain = "acme.notadomain" | ||
) | ||
|
||
func TestResolveMiekg(t *testing.T) { | ||
|
||
dns.HandleFunc(".", func(w dns.ResponseWriter, r *dns.Msg) { | ||
m := &dns.Msg{} | ||
m.SetReply(r) | ||
localIP := net.ParseIP("127.0.0.1") | ||
defer func() { | ||
err := w.WriteMsg(m) | ||
if err != nil { | ||
t.Logf("got error writing dns message: %s", err) | ||
} | ||
}() | ||
if len(r.Question) == 0 { | ||
m.RecursionAvailable = true | ||
m.SetRcode(r, dns.RcodeRefused) | ||
return | ||
} | ||
|
||
q := r.Question[0] | ||
|
||
if q.Name == fakeDomain+"." { | ||
m.Answer = []dns.RR{&dns.A{ | ||
Hdr: dns.RR_Header{ | ||
Name: q.Name, | ||
Rrtype: dns.TypeA, | ||
Class: dns.ClassINET, | ||
Ttl: 1, | ||
}, | ||
A: localIP, | ||
}} | ||
} else { | ||
m.SetRcode(r, dns.RcodeNameError) | ||
} | ||
}) | ||
const payload = "there is no cloud, just someone else's computer" | ||
|
||
done := make(chan struct{}) | ||
|
||
ds := dns.Server{ | ||
Addr: "127.0.0.1:0", | ||
Net: "udp", | ||
UDPSize: dns.MinMsgSize, | ||
ReadTimeout: 2 * time.Second, | ||
WriteTimeout: 2 * time.Second, | ||
// Unsafe instructs the server to disregard any sanity checks and directly hand the message to | ||
// the handler. It will specifically not check if the query has the QR bit not set. | ||
Unsafe: false, | ||
NotifyStartedFunc: func() { close(done) }, | ||
} | ||
|
||
go func() { | ||
err := ds.ListenAndServe() | ||
if err != nil { | ||
t.Logf("got error during dns ListenAndServe: %s", err) | ||
} | ||
}() | ||
|
||
defer func() { | ||
_ = ds.Shutdown() | ||
}() | ||
|
||
// wait for notify function to be called, ensuring ds.PacketConn is not nil. | ||
<-done | ||
|
||
res, err := NewResolver([]string{ds.PacketConn.LocalAddr().String()}) | ||
if err != nil { | ||
t.Errorf("error from NewResolver: %s", err) | ||
return | ||
} | ||
net.DefaultResolver = res | ||
|
||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
fmt.Fprintln(w, payload) | ||
})) | ||
defer ts.Close() | ||
|
||
tsurl, _ := url.Parse(ts.URL) | ||
|
||
_, hport, err := net.SplitHostPort(tsurl.Host) | ||
if err != nil { | ||
t.Errorf("could not parse port from httptest url %s: %s", ts.URL, err) | ||
return | ||
} | ||
tsurl.Host = net.JoinHostPort(fakeDomain, hport) | ||
resp, err := http.Get(tsurl.String()) | ||
if err != nil { | ||
t.Errorf("failed resolver round trip: %s", err) | ||
return | ||
} | ||
body, err := ioutil.ReadAll(resp.Body) | ||
if err != nil { | ||
t.Errorf("failed to read respose body") | ||
return | ||
} | ||
if strings.TrimSpace(string(body)) != payload { | ||
t.Errorf("body mismatch, got: '%s', expected: '%s'", body, payload) | ||
} | ||
} | ||
|
||
func TestResolveAddresses(t *testing.T) { | ||
table := map[string]struct { | ||
input []string | ||
want []string | ||
expectError bool | ||
expectMismatch bool | ||
}{ | ||
"Good list": { | ||
input: []string{ | ||
"8.8.8.8:53", | ||
"9.9.9.9:1234", | ||
"2.3.4.5", | ||
}, | ||
want: []string{ | ||
"8.8.8.8:53", | ||
"9.9.9.9:1234", | ||
"2.3.4.5:53", | ||
}, | ||
expectError: false, | ||
expectMismatch: false, | ||
}, | ||
"Mismatch list": { | ||
input: []string{ | ||
"9.9.9.9:1234", | ||
}, | ||
want: []string{ | ||
"9.9.9.9:53", | ||
}, | ||
expectError: false, | ||
expectMismatch: true, | ||
}, | ||
"Parse error list": { | ||
input: []string{ | ||
"abcd.com:53", | ||
}, | ||
expectError: true, | ||
}, | ||
} | ||
for subtest, tdata := range table { | ||
t.Run(subtest, func(t *testing.T) { | ||
addrs, err := normalizeAddrs(tdata.input) | ||
if tdata.expectError { | ||
if err == nil { | ||
t.Error("expected error, got none") | ||
} | ||
return | ||
|
||
} | ||
|
||
if err != nil { | ||
t.Errorf("expected nil error, got: %s", err) | ||
return | ||
} | ||
|
||
match := true | ||
if len(tdata.want) != len(addrs) { | ||
match = false | ||
} else { | ||
for i, addr := range addrs { | ||
if addr != tdata.want[i] { | ||
match = false | ||
break | ||
} | ||
} | ||
} | ||
if !tdata.expectMismatch && !match { | ||
t.Errorf("unexpected mismatch, input: %#v, want: %#v", addrs, tdata.want) | ||
} | ||
|
||
}) | ||
} | ||
|
||
} |