Skip to content

Commit

Permalink
Refactor Attacker.hit and simplify tests
Browse files Browse the repository at this point in the history
  • Loading branch information
tsenart committed Jan 3, 2015
1 parent 7734f2b commit 7d37fa4
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 123 deletions.
41 changes: 23 additions & 18 deletions lib/attack.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package vegeta
import (
"crypto/tls"
"fmt"
"io/ioutil"
"net"
"net/http"
"strings"
Expand Down Expand Up @@ -167,42 +166,48 @@ func (a *Attacker) Attack(tr Targeter, rate uint64, du time.Duration) chan *Resu
func (a *Attacker) Stop() { close(a.stop) }

func (a *Attacker) hit(tr Targeter, tm time.Time) *Result {
res := Result{Timestamp: tm}
defer func() { res.Latency = time.Since(tm) }()
var (
res = Result{Timestamp: tm}
err error
)

defer func() {
res.Latency = time.Since(tm)
if err != nil {
res.Error = err.Error()
}
}()

tgt, err := tr()
if err != nil {
res.Error = err.Error()
return &res
}

req, err := tgt.Request()
if err != nil {
res.Error = err.Error()
return &res
}

r, err := a.client.Do(req)
if err != nil {
// ignore redirect errors when the user set --redirects=NoFollow
if a.redirects != NoFollow || !strings.Contains(err.Error(), "stopped after") {
res.Error = err.Error()
if a.redirects == NoFollow && strings.Contains(err.Error(), "stopped after") {
err = nil
}
return &res
}
defer r.Body.Close()
r.Body.Close()

res.BytesOut = uint64(req.ContentLength)
res.Code = uint16(r.StatusCode)
if req.ContentLength != -1 {
res.BytesOut = uint64(req.ContentLength)
}

if body, err := ioutil.ReadAll(r.Body); err != nil {
res.Error = err.Error()
} else if res.BytesIn = uint64(len(body)); res.Code < 200 || res.Code >= 400 {
if len(body) != 0 {
res.Error = string(body)
} else {
res.Error = r.Status
}
if r.ContentLength != -1 {
res.BytesIn = uint64(r.ContentLength)
}

if res.Code = uint16(r.StatusCode); res.Code < 200 || res.Code >= 400 {
res.Error = r.Status
}

return &res
Expand Down
155 changes: 50 additions & 105 deletions lib/attack_test.go
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
package vegeta

import (
"crypto/tls"
"fmt"
"io"
"net"
"net/http"
"net/http/httptest"
"reflect"
"strings"
"sync/atomic"
"testing"
"time"
)

func TestAttackRate(t *testing.T) {
t.Parallel()

server := httptest.NewServer(
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}),
)
Expand All @@ -24,167 +24,112 @@ func TestAttackRate(t *testing.T) {
for _ = range atk.Attack(tr, rate, 1*time.Second) {
hits++
}
if hits != rate {
t.Fatalf("Wrong number of hits: want %d, got %d\n", rate, hits)
if got, want := hits, rate; got != want {
t.Fatalf("got: %v, want: %v", rate, hits)
}
}

func TestDefaultAttackerCertConfig(t *testing.T) {
t.Parallel()

server := httptest.NewTLSServer(
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}),
)
func TestTLSConfig(t *testing.T) {
atk := NewAttacker()
request, _ := http.NewRequest("GET", server.URL, nil)
_, err := atk.client.Do(request)
if err != nil && strings.Contains(err.Error(), "x509: certificate signed by unknown authority") {
t.Errorf("Invalid certificates should be ignored: Got `%s`", err)
got := atk.client.Transport.(*http.Transport).TLSClientConfig
if want := (&tls.Config{InsecureSkipVerify: true}); !reflect.DeepEqual(got, want) {
t.Fatalf("got: %+v, want: %+v", got, want)
}
}

func TestRedirects(t *testing.T) {
t.Parallel()

var servers [2]*httptest.Server
var hits uint64

for i := range servers {
servers[i] = httptest.NewServer(
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
atomic.AddUint64(&hits, 1)
http.Redirect(w, r, servers[(i+1)%2].URL, 302)
}),
)
}

atk := NewAttacker(Redirects(2))
tr := NewStaticTargeter(&Target{Method: "GET", URL: servers[0].URL})
var rate uint64 = 10
results := atk.Attack(tr, rate, 1*time.Second)

want := fmt.Sprintf("stopped after %d redirects", 2)
for result := range results {
if !strings.Contains(result.Error, want) {
t.Fatalf("Expected error to be: %s, Got: %s", want, result.Error)
}
}

if want, got := rate*(2+1), hits; want != got {
t.Fatalf("Expected hits to be: %d, Got: %d", want, got)
server := httptest.NewServer(
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, "/redirect", 302)
}),
)
redirects := 2
atk := NewAttacker(Redirects(redirects))
tr := NewStaticTargeter(&Target{Method: "GET", URL: server.URL})
res := atk.hit(tr, time.Now())
want := fmt.Sprintf("stopped after %d redirects", redirects)
if got := res.Error; !strings.HasSuffix(got, want) {
t.Fatalf("want: '%v' in '%v'", want, got)
}
}

func TestMarkRedirectsAsSuccess(t *testing.T) {
t.Parallel()

var server *httptest.Server
var hits uint64

server = httptest.NewServer(
func TestNoFollow(t *testing.T) {
server := httptest.NewServer(
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
atomic.AddUint64(&hits, 1)
http.Redirect(w, r, "/redirect-here", 302)
}),
)

atk := NewAttacker(Redirects(-1))
atk := NewAttacker(Redirects(NoFollow))
tr := NewStaticTargeter(&Target{Method: "GET", URL: server.URL})
var rate uint64 = 10
results := atk.Attack(tr, rate, 1*time.Second)

for result := range results {
if result.Error != "" {
t.Fatalf("Unexpected error: %s", result.Error)
}
}

if want, got := rate, hits; want != got {
t.Fatalf("Expected hits to be: %d, Got: %d", want, got)
if res := atk.hit(tr, time.Now()); res.Error != "" {
t.Fatalf("got err: %v", res.Error)
}
}

func TestTimeout(t *testing.T) {
t.Parallel()

server := httptest.NewServer(
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
<-time.After(20 * time.Millisecond)
}),
)

atk := NewAttacker(Timeout(10 * time.Millisecond))
tr := NewStaticTargeter(&Target{Method: "GET", URL: server.URL})
results := atk.Attack(tr, 1, 1*time.Second)

res := atk.hit(tr, time.Now())
want := "net/http: timeout awaiting response headers"
for result := range results {
if !strings.Contains(result.Error, want) {
t.Fatalf("Expected error to be: %s, Got: %s", want, result.Error)
}
if got := res.Error; !strings.HasSuffix(got, want) {
t.Fatalf("want: '%v' in '%v'", want, got)
}
}

func TestLocalAddr(t *testing.T) {
t.Parallel()

addr, err := net.ResolveIPAddr("ip", "127.0.0.1")
if err != nil {
t.Fatal(err)
}

server := httptest.NewServer(
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
host, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
if got, _, err := net.SplitHostPort(r.RemoteAddr); err != nil {
t.Fatal(err)
}

if host != addr.String() {
t.Fatalf("Wrong source address. Want %s, Got %s", addr, host)
} else if want := addr.String(); got != want {
t.Fatalf("wrong source address. got %v, want: %v", got, want)
}
}),
)

atk := NewAttacker(LocalAddr(*addr))
tr := NewStaticTargeter(&Target{Method: "GET", URL: server.URL})

for result := range atk.Attack(tr, 1, 1*time.Second) {
if result.Error != "" {
t.Fatal(result.Error)
}
}
atk.hit(tr, time.Now())
}

func TestKeepAlive(t *testing.T) {
t.Parallel()

atk := NewAttacker(KeepAlive(false))

if atk.dialer.KeepAlive != 0 {
t.Fatalf("Dialer KeepAlive is not disabled. Want 0. Got %d", atk.dialer.KeepAlive)
if got, want := atk.dialer.KeepAlive, time.Duration(0); got != want {
t.Fatalf("got: %v, want: %v", got, want)
}

disableKeepAlive := atk.client.Transport.(*http.Transport).DisableKeepAlives
if disableKeepAlive == false {
t.Fatalf("Transport DisableKeepAlives is not enabled. Want true. Got %t", disableKeepAlive)
got := atk.client.Transport.(*http.Transport).DisableKeepAlives
if want := true; got != want {
t.Fatalf("got: %v, want: %v", got, want)
}
}

func TestStatusCodeErrors(t *testing.T) {
t.Parallel()

server := httptest.NewServer(
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusBadRequest)
}),
)

atk := NewAttacker()
tr := NewStaticTargeter(&Target{Method: "GET", URL: server.URL})
for result := range atk.Attack(tr, 1, 1*time.Second) {
if got, want := result.Error, "400 Bad Request"; got != want {
t.Fatalf("got: %s, want: %s", got, want)
}
res := atk.hit(tr, time.Now())
if got, want := res.Error, "400 Bad Request"; got != want {
t.Fatalf("got: %v, want: %v", got, want)
}
}

func TestBadTargeterError(t *testing.T) {
atk := NewAttacker()
tr := func() (*Target, error) { return nil, io.EOF }
res := atk.hit(tr, time.Now())
if got, want := res.Error, io.EOF.Error(); got != want {
t.Fatalf("got: %v, want: %v", got, want)
}
}

0 comments on commit 7d37fa4

Please sign in to comment.